Cleanup dashboard utils and move pages to views
This commit is contained in:
parent
495ddc13f9
commit
36f8737075
26 changed files with 904 additions and 514 deletions
|
|
@ -11,11 +11,11 @@ Pages:
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.inference_page import render_inference_page
|
from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.model_state_page import render_model_state_page
|
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
from entropice.dashboard.training_analysis_page import render_training_analysis_page
|
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||||
from entropice.dashboard.training_data_page import render_training_data_page
|
from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import pydeck as pdk
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from shapely.geometry import shape
|
from shapely.geometry import shape
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_cmap, get_palette
|
from entropice.dashboard.utils.colors import get_cmap, get_palette
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import pydeck as pdk
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from shapely.geometry import shape
|
from shapely.geometry import shape
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.data import TrainingResult
|
from entropice.dashboard.utils.data import TrainingResult
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import streamlit as st
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from shapely.geometry import shape
|
from shapely.geometry import shape
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_cmap
|
from entropice.dashboard.utils.colors import get_cmap
|
||||||
|
|
||||||
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations
|
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import pydeck as pdk
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from shapely.geometry import shape
|
from shapely.geometry import shape
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.ml.dataset import CategoricalTrainingDataset
|
from entropice.ml.dataset import CategoricalTrainingDataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,7 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
|
||||||
cmap = plt.get_cmap("viridis")
|
cmap = plt.get_cmap("viridis")
|
||||||
# Sample colors evenly across the colormap
|
# Sample colors evenly across the colormap
|
||||||
indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends
|
indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends
|
||||||
base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]
|
base_colors = [mcolors.rgb2hex(cmap(idx).tolist()[:3]) for idx in indices]
|
||||||
|
|
||||||
# Create matplotlib colormap (for ultraplot and geopandas)
|
# Create matplotlib colormap (for ultraplot and geopandas)
|
||||||
matplotlib_cmap = mcolors.ListedColormap(base_colors)
|
matplotlib_cmap = mcolors.ListedColormap(base_colors)
|
||||||
96
src/entropice/dashboard/utils/formatters.py
Normal file
96
src/entropice/dashboard/utils/formatters.py
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
"""Formatters for dashboard display."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from entropice.utils.types import Grid, Model, Task
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelDisplayInfo:
    """Display metadata for a single ML model option in the dashboard."""

    keyword: Model  # machine-readable model identifier used throughout the codebase
    short: str  # short human-readable name (e.g. for dropdowns)
    long: str  # fully spelled-out name (e.g. for tooltips / headings)

    def __str__(self) -> str:
        # Rendered as "Short (Long Name)" when both forms are shown together.
        return f"{self.short} ({self.long})"


# Lookup table mapping each supported model keyword to its display metadata.
model_display_infos: dict[Model, ModelDisplayInfo] = {
    "espa": ModelDisplayInfo(
        keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm"
    ),
    "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
    "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),
    "knn": ModelDisplayInfo(keyword="knn", short="k-NN", long="k-Nearest Neighbors Classifier"),
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskDisplayInfo:
    """Display metadata for a single prediction task in the dashboard."""

    keyword: Task  # machine-readable task identifier
    display_name: str  # human-readable task name
    explanation: str  # one-sentence description shown to the user


# Lookup table mapping each supported task keyword to its display metadata.
task_display_infos: dict[Task, TaskDisplayInfo] = {
    "binary": TaskDisplayInfo(
        keyword="binary",
        display_name="Binary Classification",
        explanation="Classify each grid cell as containing or not containing the target feature.",
    ),
    "count": TaskDisplayInfo(
        keyword="count",
        display_name="Count Prediction",
        explanation="Predict the number of target features present within each grid cell.",
    ),
    "density": TaskDisplayInfo(
        keyword="density",
        display_name="Density Estimation",
        explanation="Estimate the density of target features within each grid cell.",
    ),
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainingResultDisplayInfo:
    """Human-readable identification of a single training result."""

    task: Task
    model: Model
    grid: Grid
    level: int
    timestamp: datetime

    def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
        """Build a one-line label, either task-first (default) or model-first."""
        pretty_task = self.task.capitalize()
        pretty_model = self.model.upper()
        # Grid/level and timestamp always form the trailing part of the label.
        suffix = f"{self.grid.capitalize()}-{self.level} ({self.timestamp.strftime('%Y-%m-%d %H:%M')})"

        if format_type == "model_first":
            return f"{pretty_model} - {pretty_task} - {suffix}"
        return f"{pretty_task} - {pretty_model} - {suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').

    """
    # str.capitalize() uppercases the first character and leaves digits
    # untouched, so "f1" already becomes "F1" -- the previous explicit
    # special case for "f1" was redundant and has been removed.
    return " ".join(part.capitalize() for part in metric.split("_"))
|
||||||
163
src/entropice/dashboard/utils/loaders.py
Normal file
163
src/entropice/dashboard/utils/loaders.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Data utilities for Entropice dashboard."""
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import antimeridian
|
||||||
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
import toml
|
||||||
|
import xarray as xr
|
||||||
|
from shapely.geometry import shape
|
||||||
|
|
||||||
|
import entropice.spatial.grids
|
||||||
|
import entropice.utils.paths
|
||||||
|
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||||
|
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||||
|
from entropice.ml.training import TrainingSettings
|
||||||
|
from entropice.utils.types import L2SourceDataset, Task
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    Falls back to returning *geom* unchanged (after showing a streamlit
    error) if the antimeridian fix raises a ValueError.
    """
    try:
        # shape() is kept inside the try on purpose: it can also raise
        # ValueError for geometries the fix could not repair.
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainingResult:
    """A single hyperparameter-search result directory loaded from disk."""

    path: Path  # result directory containing all artifact files
    settings: TrainingSettings  # settings the search was run with
    results: pd.DataFrame  # raw search results (one row per candidate)
    created_at: float  # directory creation time (st_ctime, epoch seconds)
    available_metrics: list[str]  # metric names found in the result columns

    @classmethod
    def from_path(cls, result_path: Path) -> "TrainingResult":
        """Load a TrainingResult from a given result directory path.

        Raises:
            FileNotFoundError: If any of the required artifact files is missing.

        """
        result_file = result_path / "search_results.parquet"
        preds_file = result_path / "predicted_probabilities.parquet"
        settings_file = result_path / "search_settings.toml"
        if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
            raise FileNotFoundError(f"Missing required files in {result_path}")

        created_at = result_path.stat().st_ctime
        settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
        results = pd.read_parquet(result_file)

        # Metric columns are stored as "mean_test_<metric>"; strip the prefix.
        available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]

        return cls(
            path=result_path,
            settings=settings,
            results=results,
            created_at=created_at,
            available_metrics=available_metrics,
        )

    @property
    def display_info(self) -> TrainingResultDisplayInfo:
        # Derive the human-readable display metadata from the stored settings.
        return TrainingResultDisplayInfo(
            task=self.settings.task,
            model=self.settings.model,
            grid=self.settings.grid,
            level=self.settings.level,
            timestamp=datetime.fromtimestamp(self.created_at),
        )

    def load_best_model(self) -> object | None:
        """Load the best model from a training result.

        Returns None when the model file is absent or unpicklable; errors are
        surfaced via st.error rather than raised.
        """
        model_file = self.path / "best_estimator_model.pkl"
        if not model_file.exists():
            return None

        try:
            # NOTE(review): pickle.load is only safe for trusted, locally
            # produced artifacts -- never load result dirs from untrusted sources.
            with open(model_file, "rb") as f:
                model = pickle.load(f)
            return model
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None

    def load_model_state(self) -> xr.Dataset | None:
        """Load the model state from a training result.

        Returns None when the state file is absent or unreadable.
        """
        state_file = self.path / "best_estimator_state.nc"
        if not state_file.exists():
            return None

        try:
            state = xr.open_dataset(state_file, engine="h5netcdf")
            return state
        except Exception as e:
            st.error(f"Error loading model state: {e}")
            return None

    def load_predictions(self) -> pd.DataFrame | None:
        """Load predictions from a training result.

        Returns None when the predictions file is absent or unreadable.
        """
        preds_file = self.path / "predicted_probabilities.parquet"
        if not preds_file.exists():
            return None

        try:
            preds = pd.read_parquet(preds_file)
            return preds
        except Exception as e:
            st.error(f"Error loading predictions: {e}")
            return None
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
    """Load every complete training result found in the results directory.

    Robustness fix: directories missing one of the required artifact files
    (e.g. an aborted or in-progress run) are skipped with a warning instead
    of raising FileNotFoundError and breaking the whole dashboard.

    Returns:
        Training results sorted by creation time, most recent first.

    """
    results_dir = entropice.utils.paths.RESULTS_DIR
    training_results: list[TrainingResult] = []
    for result_path in results_dir.iterdir():
        if not result_path.is_dir():
            continue
        try:
            training_result = TrainingResult.from_path(result_path)
        except FileNotFoundError as e:
            # One incomplete run should not prevent loading all the others.
            st.warning(f"Skipping incomplete result directory {result_path}: {e}")
            continue
        training_results.append(training_result)

    # Sort by creation time (most recent first)
    training_results.sort(key=lambda tr: tr.created_at, reverse=True)
    return training_results
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
    """Load training data for all three tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    # Build one categorical training dataset per task, all on the CPU.
    return {task: e.create_cat_training_dataset(task, device="cpu") for task in ("binary", "count", "density")}
|
||||||
|
|
||||||
|
|
||||||
|
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple of the xarray.Dataset with the raw data for the specified source
        and the target GeoDataFrame it was read against.

    """
    # NOTE(review): uses private DatasetEnsemble methods (_read_target,
    # _read_member); consider exposing a public accessor instead.
    targets = e._read_target()

    # Eagerly load the member data (lazy=False) so downstream consumers can
    # work with in-memory values.
    ds = e._read_member(source, targets, lazy=False)

    return ds, targets
|
||||||
482
src/entropice/dashboard/utils/stats.py
Normal file
482
src/entropice/dashboard/utils/stats.py
Normal file
|
|
@ -0,0 +1,482 @@
|
||||||
|
"""General Statistics shared across multiple dashboard pages.
|
||||||
|
|
||||||
|
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Literal, cast, get_args
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
import entropice.spatial.grids
|
||||||
|
import entropice.utils.paths
|
||||||
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
||||||
|
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MemberStatistics:
    """Statistics for a specific dataset member."""

    feature_count: int  # Number of features from this member
    variable_names: list[str]  # Names of variables from this member
    dimensions: dict[str, int]  # Dimension name to size mapping
    coordinates: list[str]  # Names of coordinates from this member
    size_bytes: int  # Size of this member's data on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
        """Compute member statistics from zarr metadata without loading data values.

        Raises:
            NotImplementedError: If the member is not recognized.

        """
        # Resolve the on-disk zarr store for the requested member.
        if member == "AlphaEarth":
            store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            # e.g. "ERA5-seasonal" -> aggregation keyword "seasonal"
            era5_agg = member.split("-")[1]
            store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level)  # ty:ignore[invalid-argument-type]
        elif member == "ArcticDEM":
            store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
        else:
            raise NotImplementedError(f"Member {member} not implemented.")

        # NOTE(review): for a directory-backed store, st_size is the directory
        # entry size, not the total content size -- confirm this is intended.
        size_bytes = store.stat().st_size
        ds = xr.open_zarr(store, consolidated=False)

        # Delete all coordinates which are not in the dimension
        for coord in ds.coords:
            if coord not in ds.dims:
                ds = ds.drop_vars(coord)
        # The flattened feature count is the variable count multiplied by the
        # size of every non-cell dimension.
        n_cols_member = len(ds.data_vars)
        for dim in ds.sizes:
            if dim != "cell_ids":
                n_cols_member *= ds.sizes[dim]

        return cls(
            feature_count=n_cols_member,
            variable_names=list(ds.data_vars),
            dimensions=dict(ds.sizes),
            coordinates=list(ds.coords),
            size_bytes=size_bytes,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TargetStatistics:
    """Statistics for a specific target dataset."""

    training_cells: dict[Task, int]  # Number of cells used for training
    coverage: dict[Task, float]  # Percentage of total cells covered by training data
    class_counts: dict[Task, dict[str, int]]  # Class name to count mapping per task
    class_distribution: dict[Task, dict[str, float]]  # Class name to percentage mapping per task
    size_bytes: int  # Size of the target dataset on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
        """Compute per-task coverage and class-distribution statistics for one target.

        Args:
            grid: Grid type the target was aggregated to.
            level: Grid refinement level.
            target: Which target dataset to analyze.
            total_cells: Total number of cells in the grid (denominator for coverage).

        Raises:
            NotImplementedError: If the target dataset is not recognized.

        """
        # Resolve the on-disk parquet file for the requested target.
        if target == "darts_rts":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
        elif target == "darts_mllabels":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
        else:
            raise NotImplementedError(f"Target {target} not implemented.")
        target_gdf = gpd.read_parquet(target_store)
        size_bytes = target_store.stat().st_size
        training_cells: dict[Task, int] = {}
        training_coverage: dict[Task, float] = {}
        class_counts: dict[Task, dict[str, int]] = {}
        class_distribution: dict[Task, dict[str, float]] = {}
        tasks = cast(list[Task], get_args(Task))
        for task in tasks:
            # Columns holding this task's raw value and the coverage flag.
            task_col = taskcol[task][target]
            cov_col = covcol[target]

            # Only cells flagged as covered contribute to training statistics.
            task_gdf = target_gdf[target_gdf[cov_col]]
            training_cells[task] = len(task_gdf)
            training_coverage[task] = len(task_gdf) / total_cells * 100

            model_labels = task_gdf[task_col].dropna()
            # Bin raw values into the categorical class labels used for display.
            if task == "binary":
                binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
            elif task == "count":
                binned = bin_values(model_labels.astype(int), task=task)
            elif task == "density":
                binned = bin_values(model_labels, task=task)
            else:
                raise ValueError("Invalid task.")
            counts = binned.value_counts()
            distribution = counts / counts.sum() * 100
            class_counts[task] = counts.to_dict()  # ty:ignore[invalid-assignment]
            class_distribution[task] = distribution.to_dict()
        # NOTE(review): returns the class by name rather than `cls`; fine unless
        # subclassed -- consider `cls(...)` for consistency with MemberStatistics.
        return TargetStatistics(
            training_cells=training_cells,
            coverage=training_coverage,
            class_counts=class_counts,
            class_distribution=class_distribution,
            size_bytes=size_bytes,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DatasetStatistics:
    """Statistics for a potential dataset at a specific grid and level.

    These statistics are meant to be easily computed without loading a full dataset into memory.
    Further, they are meant to give an overview of all potential dataset compositions for a given grid and level.
    """

    total_features: int  # Total number of features available in the dataset
    total_cells: int  # Total number of grid cells potentially covered
    size_bytes: int  # Size of the dataset on disk in bytes
    members: dict[L2SourceDataset, MemberStatistics]  # Statistics per source dataset member
    target: dict[TargetDataset, TargetStatistics]  # Statistics per target dataset
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
    """Compute dataset statistics for every supported grid/level combination.

    Cached by streamlit since the computation touches many files on disk.
    """
    dataset_stats: dict[GridLevel, DatasetStatistics] = {}
    # All grid/level combinations the dashboard supports.
    grid_levels: set[tuple[Grid, int]] = {
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    }
    for grid, level in grid_levels:
        with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
            grid_gdf = entropice.spatial.grids.open(grid, level)  # Ensure grid is registered
            total_cells = len(grid_gdf)
            assert total_cells > 0, "Grid must contain at least one cell."
            # Per-target statistics (coverage, class distributions).
            target_statistics: dict[TargetDataset, TargetStatistics] = {}
            targets = cast(list[TargetDataset], get_args(TargetDataset))
            for target in targets:
                target_statistics[target] = TargetStatistics.compute(
                    grid=grid, level=level, target=target, total_cells=total_cells
                )
            # Per-member statistics (feature counts, on-disk sizes).
            member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
            members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
            for member in members:
                member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)

            # Aggregate totals across all members and targets.
            total_features = sum(ms.feature_count for ms in member_statistics.values())
            total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
                ts.size_bytes for ts in target_statistics.values()
            )
            grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
            dataset_stats[grid_level] = DatasetStatistics(
                total_features=total_features,
                total_cells=total_cells,
                size_bytes=total_size_bytes,
                members=member_statistics,
                target=target_statistics,
            )

    return dataset_stats
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Per-member statistics of an already materialized ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Share (fraction 0-1) of features from this member in the ensemble
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: gpd.GeoDataFrame,
        member: L2SourceDataset,
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute statistics for the columns belonging to a single member.

        Args:
            dataset: The materialized ensemble dataset (columns from all members).
            member: The source dataset whose columns should be analyzed.
            n_features: Total feature count of the ensemble (denominator for p_features).

        Raises:
            NotImplementedError: If the member is not recognized.

        """
        # Bugfix: a boolean COLUMN mask must be applied via `.loc[:, mask]`;
        # `dataset[mask]` treats the mask as a ROW indexer and raises
        # "Item wrong length" unless n_rows happens to equal n_columns.
        if member == "AlphaEarth":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("embeddings_")]
        elif member == "ArcticDEM":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("arcticdem_")]
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            # ERA5 columns are distinguished by the number of "_" separators and
            # the presence of a "_summer"/"_winter" season suffix.
            era5_cols = dataset.columns.str.startswith("era5_")
            cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
            cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
            cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
            if member == "ERA5-yearly":
                member_dataset = dataset.loc[:, cols_with_three_splits]
            elif member == "ERA5-seasonal":
                member_dataset = dataset.loc[:, cols_with_summer_winter & cols_with_four_splits]
            elif member == "ERA5-shoulder":
                member_dataset = dataset.loc[:, ~cols_with_summer_winter & cols_with_four_splits]
        else:
            raise NotImplementedError(f"Member {member} not implemented.")

        size_bytes_member = member_dataset.memory_usage(deep=True).sum()
        n_features_member = len(member_dataset.columns)
        p_features_member = n_features_member / n_features
        n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
        # Use cls(...) so subclasses get instances of their own type.
        return cls(
            n_features=n_features_member,
            p_features=p_features_member,
            n_nanrows=n_rows_with_nan_member,
            size_bytes=size_bytes_member,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of the dataset on disk in bytes
    members: dict[L2SourceDataset, EnsembleMemberStatistics]  # Statistics per source dataset member

    @classmethod
    def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
        """Compute ensemble-level statistics, materializing the dataset if not given.

        Args:
            ensemble: The ensemble describing the dataset composition.
            dataset: Optional pre-materialized dataset to avoid re-creating it.

        """
        # Bugfix: `dataset or ...` raises "The truth value of a DataFrame is
        # ambiguous" whenever a dataset IS passed; test explicitly for None.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        # Assert that no column is all-NaN
        assert not dataset.isna().all("index").any(), "Some input columns are all NaN"

        size_bytes = dataset.memory_usage(deep=True).sum()
        n_features = len(dataset.columns)
        n_cells = len(dataset)

        # Number of rows which contain any NaN
        n_rows_with_nan = dataset.isna().any(axis=1).sum()

        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {}
        for member in ensemble.members:
            member_statistics[member] = EnsembleMemberStatistics.compute(
                dataset=dataset,
                member=member,
                n_features=n_features,
            )
        return cls(
            n_features=n_features,
            n_cells=n_cells,
            n_nanrows=n_rows_with_nan,
            size_bytes=size_bytes,
            members=member_statistics,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TrainingDatasetStatistics:
    """Statistics of the categorical train/test split for a single task."""

    n_samples: int  # Total number of samples in the dataset
    n_features: int  # Number of features
    feature_names: list[str]  # Names of all features
    n_training_samples: int  # Number of cells used for training
    n_test_samples: int  # Number of cells used for testing
    train_test_ratio: float  # Ratio of training to test samples

    class_labels: list[str]  # Ordered list of class labels
    class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]]  # Min/max raw values per class
    n_classes: int  # Number of classes

    training_class_counts: dict[str, int]  # Class counts in training set
    training_class_distribution: dict[str, float]  # Class percentages in training set
    test_class_counts: dict[str, int]  # Class counts in test set
    test_class_distribution: dict[str, float]  # Class percentages in test set

    raw_value_min: float  # Minimum raw target value
    raw_value_max: float  # Maximum raw target value
    raw_value_mean: float  # Mean raw target value
    raw_value_median: float  # Median raw target value
    raw_value_std: float  # Standard deviation of raw target values

    imbalance_ratio: float  # Smallest class count / largest class count (overall)
    size_bytes: int  # Total memory usage of features in bytes

    @classmethod
    def compute(
        cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
    ) -> "TrainingDatasetStatistics":
        """Compute train/test split statistics for one task.

        Args:
            ensemble: The ensemble describing the dataset composition.
            task: The task to split and bin for.
            dataset: Optional pre-materialized dataset to avoid re-creating it.

        """
        # Bugfix: `dataset or ...` raises "The truth value of a DataFrame is
        # ambiguous" whenever a dataset IS passed; test explicitly for None.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")

        # Sample counts
        n_samples = len(categorical_dataset)
        n_training_samples = len(categorical_dataset.y.train)
        n_test_samples = len(categorical_dataset.y.test)
        train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0

        # Feature statistics
        n_features = len(categorical_dataset.X.data.columns)
        feature_names = list(categorical_dataset.X.data.columns)
        size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()

        # Class information
        class_labels = categorical_dataset.y.labels
        class_intervals = categorical_dataset.y.intervals
        n_classes = len(class_labels)

        # Training class distribution (counts indexed by integer class id).
        train_y_series = pd.Series(categorical_dataset.y.train)
        train_counts = train_y_series.value_counts().sort_index()
        training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
        train_total = sum(training_class_counts.values())
        training_class_distribution = {
            k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
        }

        # Test class distribution
        test_y_series = pd.Series(categorical_dataset.y.test)
        test_counts = test_y_series.value_counts().sort_index()
        test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
        test_total = sum(test_class_counts.values())
        test_class_distribution = {
            k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
        }

        # Raw value statistics
        raw_values = categorical_dataset.y.raw_values
        raw_value_min = float(raw_values.min())
        raw_value_max = float(raw_values.max())
        raw_value_mean = float(raw_values.mean())
        raw_value_median = float(raw_values.median())
        raw_value_std = float(raw_values.std())

        # Imbalance ratio (smallest class / largest class across both splits)
        all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
        nonzero_counts = [c for c in all_counts if c > 0]
        imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0

        return cls(
            n_samples=n_samples,
            n_features=n_features,
            feature_names=feature_names,
            n_training_samples=n_training_samples,
            n_test_samples=n_test_samples,
            train_test_ratio=train_test_ratio,
            class_labels=class_labels,
            class_intervals=class_intervals,
            n_classes=n_classes,
            training_class_counts=training_class_counts,
            training_class_distribution=training_class_distribution,
            test_class_counts=test_class_counts,
            test_class_distribution=test_class_distribution,
            raw_value_min=raw_value_min,
            raw_value_max=raw_value_max,
            raw_value_mean=raw_value_mean,
            raw_value_median=raw_value_median,
            raw_value_std=raw_value_std,
            imbalance_ratio=imbalance_ratio,
            size_bytes=size_bytes,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CVMetricStatistics:
|
||||||
|
best_score: float
|
||||||
|
mean_score: float
|
||||||
|
std_score: float
|
||||||
|
worst_score: float
|
||||||
|
median_score: float
|
||||||
|
mean_cv_std: float | None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics":
|
||||||
|
"""Get cross-validation statistics for a metric."""
|
||||||
|
score_col = f"mean_test_{metric}"
|
||||||
|
std_col = f"std_test_{metric}"
|
||||||
|
|
||||||
|
if score_col not in result.results.columns:
|
||||||
|
raise ValueError(f"Metric {metric} not found in results.")
|
||||||
|
|
||||||
|
best_score = result.results[score_col].max()
|
||||||
|
mean_score = result.results[score_col].mean()
|
||||||
|
std_score = result.results[score_col].std()
|
||||||
|
worst_score = result.results[score_col].min()
|
||||||
|
median_score = result.results[score_col].median()
|
||||||
|
|
||||||
|
mean_cv_std = None
|
||||||
|
if std_col in result.results.columns:
|
||||||
|
mean_cv_std = result.results[std_col].mean()
|
||||||
|
|
||||||
|
return CVMetricStatistics(
|
||||||
|
best_score=best_score,
|
||||||
|
mean_score=mean_score,
|
||||||
|
std_score=std_score,
|
||||||
|
worst_score=worst_score,
|
||||||
|
median_score=median_score,
|
||||||
|
mean_cv_std=mean_cv_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ParameterSpaceSummary:
|
||||||
|
parameter: str
|
||||||
|
type: Literal["Numeric", "Categorical"]
|
||||||
|
min: float | None
|
||||||
|
max: float | None
|
||||||
|
mean: float | None
|
||||||
|
unique_values: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary":
|
||||||
|
param_name = param_col.replace("param_", "")
|
||||||
|
param_values = result.results[param_col].dropna()
|
||||||
|
|
||||||
|
if pd.api.types.is_numeric_dtype(param_values):
|
||||||
|
return ParameterSpaceSummary(
|
||||||
|
parameter=param_name,
|
||||||
|
type="Numeric",
|
||||||
|
min=param_values.min(),
|
||||||
|
max=param_values.max(),
|
||||||
|
mean=param_values.mean(),
|
||||||
|
unique_values=param_values.nunique(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unique_vals = param_values.unique()
|
||||||
|
return ParameterSpaceSummary(
|
||||||
|
parameter=param_name,
|
||||||
|
type="Categorical",
|
||||||
|
min=None,
|
||||||
|
max=None,
|
||||||
|
mean=None,
|
||||||
|
unique_values=len(unique_vals),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CVResultsStatistics:
|
||||||
|
metrics: dict[str, CVMetricStatistics]
|
||||||
|
parameter_summary: list[ParameterSpaceSummary]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def compute(cls, result: TrainingResult) -> "CVResultsStatistics":
|
||||||
|
"""Get cross-validation statistics for a metric."""
|
||||||
|
metrics = result.available_metrics
|
||||||
|
metric_stats: dict[str, CVMetricStatistics] = {}
|
||||||
|
for metric in metrics:
|
||||||
|
metric_stats[metric] = CVMetricStatistics.compute(result, metric)
|
||||||
|
|
||||||
|
param_cols = [col for col in result.results.columns if col.startswith("param_") and col != "params"]
|
||||||
|
summary_data = []
|
||||||
|
for param_col in param_cols:
|
||||||
|
summary_data.append(ParameterSpaceSummary.compute(result, param_col))
|
||||||
|
|
||||||
|
return CVResultsStatistics(metrics=metric_stats, parameter_summary=summary_data)
|
||||||
|
|
||||||
|
def metrics_to_dataframe(self) -> pd.DataFrame:
|
||||||
|
"""Convert metric statistics to a DataFrame."""
|
||||||
|
data = defaultdict(list)
|
||||||
|
for metric, stats in self.metrics.items():
|
||||||
|
data["Metric"].append(metric)
|
||||||
|
data["Best Score"].append(stats.best_score)
|
||||||
|
data["Mean Score"].append(stats.mean_score)
|
||||||
|
data["Std Dev"].append(stats.std_score)
|
||||||
|
data["Worst Score"].append(stats.worst_score)
|
||||||
|
data["Median Score"].append(stats.median_score)
|
||||||
|
data["Mean CV Std Dev"].append(stats.mean_cv_std)
|
||||||
|
|
||||||
|
return pd.DataFrame(data)
|
||||||
|
|
||||||
|
def parameters_to_dataframe(self) -> pd.DataFrame:
|
||||||
|
"""Convert parameter summary to a DataFrame."""
|
||||||
|
return pd.DataFrame([asdict(p) for p in self.parameter_summary])
|
||||||
|
|
@ -1,232 +0,0 @@
|
||||||
"""Training utilities for dashboard."""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import streamlit as st
|
|
||||||
import xarray as xr
|
|
||||||
|
|
||||||
from entropice.dashboard.utils.data import TrainingResult
|
|
||||||
|
|
||||||
|
|
||||||
def format_metric_name(metric: str) -> str:
|
|
||||||
"""Format metric name for display.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Split by underscore and capitalize each part
|
|
||||||
parts = metric.split("_")
|
|
||||||
# Special handling for F1
|
|
||||||
formatted_parts = []
|
|
||||||
for part in parts:
|
|
||||||
if part.lower() == "f1":
|
|
||||||
formatted_parts.append("F1")
|
|
||||||
else:
|
|
||||||
formatted_parts.append(part.capitalize())
|
|
||||||
return " ".join(formatted_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_metrics(results: pd.DataFrame) -> list[str]:
|
|
||||||
"""Get list of available metrics from results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results: DataFrame with CV results.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of metric names (without 'mean_test_' prefix).
|
|
||||||
|
|
||||||
"""
|
|
||||||
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
|
||||||
return [col.replace("mean_test_", "") for col in score_cols]
|
|
||||||
|
|
||||||
|
|
||||||
def load_best_model(result: TrainingResult):
|
|
||||||
"""Load the best model from a training result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: TrainingResult object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The loaded model object, or None if loading fails.
|
|
||||||
|
|
||||||
"""
|
|
||||||
model_file = result.path / "best_estimator_model.pkl"
|
|
||||||
if not model_file.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(model_file, "rb") as f:
|
|
||||||
model = pickle.load(f)
|
|
||||||
return model
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error loading model: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
|
|
||||||
"""Load the model state from a training result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: TrainingResult object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
xarray Dataset with model state, or None if not available.
|
|
||||||
|
|
||||||
"""
|
|
||||||
state_file = result.path / "best_estimator_state.nc"
|
|
||||||
if not state_file.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
state = xr.open_dataset(state_file, engine="h5netcdf")
|
|
||||||
return state
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error loading model state: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
|
|
||||||
"""Load predictions from a training result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: TrainingResult object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame with predictions, or None if not available.
|
|
||||||
|
|
||||||
"""
|
|
||||||
preds_file = result.path / "predicted_probabilities.parquet"
|
|
||||||
if not preds_file.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
preds = pd.read_parquet(preds_file)
|
|
||||||
return preds
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error loading predictions: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
|
|
||||||
"""Get summary of parameter space explored.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results: DataFrame with CV results.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame with parameter ranges and statistics.
|
|
||||||
|
|
||||||
"""
|
|
||||||
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
|
||||||
|
|
||||||
summary_data = []
|
|
||||||
for param_col in param_cols:
|
|
||||||
param_name = param_col.replace("param_", "")
|
|
||||||
param_values = results[param_col].dropna()
|
|
||||||
|
|
||||||
if pd.api.types.is_numeric_dtype(param_values):
|
|
||||||
summary_data.append(
|
|
||||||
{
|
|
||||||
"Parameter": param_name,
|
|
||||||
"Type": "Numeric",
|
|
||||||
"Min": f"{param_values.min():.2e}",
|
|
||||||
"Max": f"{param_values.max():.2e}",
|
|
||||||
"Mean": f"{param_values.mean():.2e}",
|
|
||||||
"Unique Values": param_values.nunique(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
unique_vals = param_values.unique()
|
|
||||||
summary_data.append(
|
|
||||||
{
|
|
||||||
"Parameter": param_name,
|
|
||||||
"Type": "Categorical",
|
|
||||||
"Min": "-",
|
|
||||||
"Max": "-",
|
|
||||||
"Mean": "-",
|
|
||||||
"Unique Values": len(unique_vals),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return pd.DataFrame(summary_data)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
|
|
||||||
"""Get cross-validation statistics for a metric.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results: DataFrame with CV results.
|
|
||||||
metric: Metric name (without 'mean_test_' prefix).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with CV statistics.
|
|
||||||
|
|
||||||
"""
|
|
||||||
score_col = f"mean_test_{metric}"
|
|
||||||
std_col = f"std_test_{metric}"
|
|
||||||
|
|
||||||
if score_col not in results.columns:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
stats = {
|
|
||||||
"best_score": results[score_col].max(),
|
|
||||||
"mean_score": results[score_col].mean(),
|
|
||||||
"std_score": results[score_col].std(),
|
|
||||||
"worst_score": results[score_col].min(),
|
|
||||||
"median_score": results[score_col].median(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if std_col in results.columns:
|
|
||||||
stats["mean_cv_std"] = results[std_col].mean()
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
|
|
||||||
"""Prepare results dataframe with binned columns for plotting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results: DataFrame with CV results.
|
|
||||||
k_bin_width: Width of bins for initial_K parameter.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame with added binned columns.
|
|
||||||
|
|
||||||
"""
|
|
||||||
results_copy = results.copy()
|
|
||||||
|
|
||||||
# Check if we have the parameters
|
|
||||||
if "param_initial_K" in results.columns:
|
|
||||||
# Bin initial_K
|
|
||||||
k_values = results["param_initial_K"].dropna()
|
|
||||||
if len(k_values) > 0:
|
|
||||||
k_min = k_values.min()
|
|
||||||
k_max = k_values.max()
|
|
||||||
k_bins = range(int(k_min), int(k_max) + k_bin_width, k_bin_width)
|
|
||||||
results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False)
|
|
||||||
|
|
||||||
if "param_eps_cl" in results.columns:
|
|
||||||
# Create logarithmic bins for eps_cl
|
|
||||||
eps_cl_values = results["param_eps_cl"].dropna()
|
|
||||||
if len(eps_cl_values) > 0 and eps_cl_values.min() > 0:
|
|
||||||
eps_cl_min = eps_cl_values.min()
|
|
||||||
eps_cl_max = eps_cl_values.max()
|
|
||||||
eps_cl_bins = np.logspace(np.log10(eps_cl_min), np.log10(eps_cl_max), num=10)
|
|
||||||
results_copy["eps_cl_binned"] = pd.cut(results["param_eps_cl"], bins=eps_cl_bins)
|
|
||||||
|
|
||||||
if "param_eps_e" in results.columns:
|
|
||||||
# Create logarithmic bins for eps_e
|
|
||||||
eps_e_values = results["param_eps_e"].dropna()
|
|
||||||
if len(eps_e_values) > 0 and eps_e_values.min() > 0:
|
|
||||||
eps_e_min = eps_e_values.min()
|
|
||||||
eps_e_max = eps_e_values.max()
|
|
||||||
eps_e_bins = np.logspace(np.log10(eps_e_min), np.log10(eps_e_max), num=10)
|
|
||||||
results_copy["eps_e_binned"] = pd.cut(results["param_eps_e"], bins=eps_e_bins)
|
|
||||||
|
|
||||||
return results_copy
|
|
||||||
|
|
@ -1,147 +1,4 @@
|
||||||
"""Data utilities for Entropice dashboard."""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import antimeridian
|
|
||||||
import pandas as pd
|
|
||||||
import streamlit as st
|
|
||||||
import toml
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from shapely.geometry import shape
|
|
||||||
|
|
||||||
import entropice.utils.paths
|
|
||||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingResult:
|
|
||||||
"""Simple wrapper of training result data."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
path: Path
|
|
||||||
settings: dict
|
|
||||||
results: pd.DataFrame
|
|
||||||
created_at: float
|
|
||||||
|
|
||||||
def get_display_name(self, format_type: str = "task_first") -> str:
|
|
||||||
"""Get formatted display name for the training result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted name string
|
|
||||||
|
|
||||||
"""
|
|
||||||
task = self.settings.get("task", "Unknown").capitalize()
|
|
||||||
model = self.settings.get("model", "Unknown").upper()
|
|
||||||
grid = self.settings.get("grid", "Unknown").capitalize()
|
|
||||||
level = self.settings.get("level", "Unknown")
|
|
||||||
timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M")
|
|
||||||
|
|
||||||
if format_type == "model_first":
|
|
||||||
return f"{model} - {task} - {grid}-{level} ({timestamp})"
|
|
||||||
else: # task_first
|
|
||||||
return f"{task} - {model} - {grid}-{level} ({timestamp})"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
|
||||||
"""Load a TrainingResult from a given result directory path."""
|
|
||||||
result_file = result_path / "search_results.parquet"
|
|
||||||
preds_file = result_path / "predicted_probabilities.parquet"
|
|
||||||
settings_file = result_path / "search_settings.toml"
|
|
||||||
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
|
|
||||||
raise FileNotFoundError(f"Missing required files in {result_path}")
|
|
||||||
|
|
||||||
created_at = result_path.stat().st_ctime
|
|
||||||
settings = toml.load(settings_file)["settings"]
|
|
||||||
results = pd.read_parquet(result_file)
|
|
||||||
|
|
||||||
# Name should be "task model grid-level (created_at)"
|
|
||||||
model = settings.get("model", "Unknown").upper()
|
|
||||||
name = (
|
|
||||||
f"{settings.get('task', 'Unknown').capitalize()} -"
|
|
||||||
f" {model} -"
|
|
||||||
f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
|
|
||||||
f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
name=name,
|
|
||||||
path=result_path,
|
|
||||||
settings=settings,
|
|
||||||
results=results,
|
|
||||||
created_at=created_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _fix_hex_geometry(geom):
|
|
||||||
"""Fix hexagon geometry crossing the antimeridian."""
|
|
||||||
try:
|
|
||||||
return shape(antimeridian.fix_shape(geom))
|
|
||||||
except ValueError as e:
|
|
||||||
st.error(f"Error fixing geometry: {e}")
|
|
||||||
return geom
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def load_all_training_results() -> list[TrainingResult]:
|
|
||||||
"""Load all training results from the results directory."""
|
|
||||||
results_dir = entropice.utils.paths.RESULTS_DIR
|
|
||||||
training_results: list[TrainingResult] = []
|
|
||||||
for result_path in results_dir.iterdir():
|
|
||||||
if not result_path.is_dir():
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
training_result = TrainingResult.from_path(result_path)
|
|
||||||
training_results.append(training_result)
|
|
||||||
except FileNotFoundError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Sort by creation time (most recent first)
|
|
||||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
|
||||||
return training_results
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]:
|
|
||||||
"""Load training data for all three tasks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
e: DatasetEnsemble object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"binary": e.create_cat_training_dataset("binary", device="cpu"),
|
|
||||||
"count": e.create_cat_training_dataset("count", device="cpu"),
|
|
||||||
"density": e.create_cat_training_dataset("density", device="cpu"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def load_source_data(e: DatasetEnsemble, source: str):
|
|
||||||
"""Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
e: DatasetEnsemble object.
|
|
||||||
source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
xarray.Dataset with the raw data for the specified source.
|
|
||||||
|
|
||||||
"""
|
|
||||||
targets = e._read_target()
|
|
||||||
|
|
||||||
# Load the member data lazily to get metadata
|
|
||||||
ds = e._read_member(source, targets, lazy=False)
|
|
||||||
|
|
||||||
return ds, targets
|
|
||||||
|
|
||||||
|
|
||||||
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
|
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
|
||||||
|
|
@ -203,7 +60,7 @@ def extract_embedding_features(
|
||||||
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||||
year=("feature", [f.split("_")[3] for f in embedding_features]),
|
year=("feature", [f.split("_")[3] for f in embedding_features]),
|
||||||
)
|
)
|
||||||
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
|
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")
|
||||||
return embedding_feature_array
|
return embedding_feature_array
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -419,9 +276,9 @@ def extract_era5_features(
|
||||||
era5_features_array = era5_features_array.assign_coords(
|
era5_features_array = era5_features_array.assign_coords(
|
||||||
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
|
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
|
||||||
)
|
)
|
||||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010
|
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
|
||||||
else:
|
else:
|
||||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature")
|
||||||
|
|
||||||
return era5_features_array
|
return era5_features_array
|
||||||
|
|
||||||
|
|
@ -472,7 +329,7 @@ def extract_arcticdem_features(
|
||||||
variable=("feature", [_extract_var_name(f) for f in arcticdem_features]),
|
variable=("feature", [_extract_var_name(f) for f in arcticdem_features]),
|
||||||
agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]),
|
agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]),
|
||||||
)
|
)
|
||||||
arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") # noqa: PD010
|
arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature")
|
||||||
return arcticdem_feature_array
|
return arcticdem_feature_array
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -503,16 +360,3 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea
|
||||||
# Extract the feature importance for common features
|
# Extract the feature importance for common features
|
||||||
common_feature_array = importance_array.sel(feature=common_features)
|
common_feature_array = importance_array.sel(feature=common_features)
|
||||||
return common_feature_array
|
return common_feature_array
|
||||||
|
|
||||||
|
|
||||||
def get_members_from_settings(settings: dict) -> list[str]:
|
|
||||||
"""Extract the list of dataset members used in training from settings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
settings: Training settings dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']).
|
|
||||||
|
|
||||||
"""
|
|
||||||
return settings.get("members", [])
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import generate_unified_colormap
|
|
||||||
from entropice.dashboard.plots.model_state import (
|
from entropice.dashboard.plots.model_state import (
|
||||||
plot_arcticdem_heatmap,
|
plot_arcticdem_heatmap,
|
||||||
plot_arcticdem_summary,
|
plot_arcticdem_summary,
|
||||||
|
|
@ -17,6 +16,7 @@ from entropice.dashboard.plots.model_state import (
|
||||||
plot_era5_time_heatmap,
|
plot_era5_time_heatmap,
|
||||||
plot_top_features,
|
plot_top_features,
|
||||||
)
|
)
|
||||||
|
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||||
from entropice.dashboard.utils.data import (
|
from entropice.dashboard.utils.data import (
|
||||||
extract_arcticdem_features,
|
extract_arcticdem_features,
|
||||||
extract_common_features,
|
extract_common_features,
|
||||||
|
|
@ -8,16 +8,17 @@ import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.data import load_all_training_results
|
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.utils.types import Grid, TargetDataset, Task
|
||||||
|
|
||||||
|
|
||||||
# Type definitions for dataset statistics
|
# Type definitions for dataset statistics
|
||||||
class GridConfig(TypedDict):
|
class GridConfig(TypedDict):
|
||||||
"""Grid configuration specification with metadata."""
|
"""Grid configuration specification with metadata."""
|
||||||
|
|
||||||
grid: str
|
grid: Grid
|
||||||
level: int
|
level: int
|
||||||
grid_name: str
|
grid_name: str
|
||||||
grid_sort_key: str
|
grid_sort_key: str
|
||||||
|
|
@ -116,7 +117,7 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
||||||
Results are cached to avoid redundant computations across different tabs.
|
Results are cached to avoid redundant computations across different tabs.
|
||||||
"""
|
"""
|
||||||
# Define all possible grid configurations
|
# Define all possible grid configurations
|
||||||
grid_configs_raw = [
|
grid_configs_raw: list[tuple[Grid, int]] = [
|
||||||
("hex", 3),
|
("hex", 3),
|
||||||
("hex", 4),
|
("hex", 4),
|
||||||
("hex", 5),
|
("hex", 5),
|
||||||
|
|
@ -144,8 +145,8 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
||||||
|
|
||||||
# Compute sample counts
|
# Compute sample counts
|
||||||
sample_counts: list[SampleCountData] = []
|
sample_counts: list[SampleCountData] = []
|
||||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
||||||
tasks = ["binary", "count", "density"]
|
tasks: list[Task] = ["binary", "count", "density"]
|
||||||
|
|
||||||
for grid_config in grid_configs:
|
for grid_config in grid_configs:
|
||||||
for target in target_datasets:
|
for target in target_datasets:
|
||||||
|
|
@ -810,15 +811,22 @@ def render_experiment_results(training_results):
|
||||||
|
|
||||||
def render_overview_page():
|
def render_overview_page():
|
||||||
"""Render the Overview page of the dashboard."""
|
"""Render the Overview page of the dashboard."""
|
||||||
st.title("🏡 Training Results Overview")
|
st.title("🏡 Entropice Overview")
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Welcome to the Entropice Dashboard! This overview page provides a summary of your
|
||||||
|
training experiments and dataset analyses.
|
||||||
|
|
||||||
|
Use the sections below to explore training results, dataset sample counts,
|
||||||
|
and feature configurations.
|
||||||
|
"""
|
||||||
|
)
|
||||||
# Load training results
|
# Load training results
|
||||||
training_results = load_all_training_results()
|
training_results = load_all_training_results()
|
||||||
|
|
||||||
if not training_results:
|
if not training_results:
|
||||||
st.warning("No training results found. Please run some training experiments first.")
|
st.warning("No training results found. Please run some training experiments first.")
|
||||||
return
|
else:
|
||||||
|
|
||||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
@ -15,7 +15,7 @@ from entropice.dashboard.plots.source_data import (
|
||||||
render_era5_plots,
|
render_era5_plots,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
||||||
from entropice.dashboard.utils.data import load_all_training_data, load_source_data
|
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.spatial import grids
|
from entropice.spatial import grids
|
||||||
|
|
||||||
|
|
@ -60,19 +60,10 @@ def render_training_data_page():
|
||||||
# Members selection
|
# Members selection
|
||||||
st.subheader("Dataset Members")
|
st.subheader("Dataset Members")
|
||||||
|
|
||||||
# Check if AlphaEarth should be disabled
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
|
|
||||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
selected_members = []
|
selected_members = []
|
||||||
|
|
||||||
for member in all_members:
|
for member in all_members:
|
||||||
if member == "AlphaEarth" and disable_alphaearth:
|
|
||||||
# Show disabled checkbox with explanation
|
|
||||||
st.checkbox(
|
|
||||||
member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
||||||
selected_members.append(member)
|
selected_members.append(member)
|
||||||
|
|
||||||
|
|
@ -6,7 +6,6 @@ Date: October 2025
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cudf
|
import cudf
|
||||||
import cuml.cluster
|
import cuml.cluster
|
||||||
|
|
@ -25,6 +24,7 @@ from rich.progress import track
|
||||||
from entropice.spatial import grids
|
from entropice.spatial import grids
|
||||||
from entropice.utils import codecs
|
from entropice.utils import codecs
|
||||||
from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store
|
from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
# Filter out the GeoDataFrame.swapaxes deprecation warning
|
# Filter out the GeoDataFrame.swapaxes deprecation warning
|
||||||
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
|
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
|
||||||
|
|
@ -55,11 +55,11 @@ def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.D
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def download(grid: Literal["hex", "healpix"], level: int):
|
def download(grid: Grid, level: int):
|
||||||
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -126,11 +126,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
def combine_to_zarr(grid: Grid, level: int):
|
||||||
"""Combine yearly embeddings parquet files into a single zarr store.
|
"""Combine yearly embeddings parquet files into a single zarr store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import datetime
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cupy_xarray
|
import cupy_xarray
|
||||||
|
|
@ -32,6 +31,7 @@ from entropice.spatial import grids, watermask
|
||||||
from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid
|
from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||||
from entropice.utils import codecs
|
from entropice.utils import codecs
|
||||||
from entropice.utils.paths import get_arcticdem_stores
|
from entropice.utils.paths import get_arcticdem_stores
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
traceback.install(show_locals=True, suppress=[cyclopts])
|
traceback.install(show_locals=True, suppress=[cyclopts])
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -330,7 +330,7 @@ def _open_adem():
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20):
|
def aggregate(grid: Grid, level: int, concurrent_partitions: int = 20):
|
||||||
mp.set_start_method("forkserver", force=True)
|
mp.set_start_method("forkserver", force=True)
|
||||||
with (
|
with (
|
||||||
dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
|
dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,6 @@ Author: Tobias Hölzer
|
||||||
Date: October 2025
|
Date: October 2025
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -17,6 +15,7 @@ from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.spatial import grids
|
from entropice.spatial import grids
|
||||||
from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -25,11 +24,11 @@ cli = cyclopts.App(name="darts-rts")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
def extract_darts_rts(grid: Grid, level: int):
|
||||||
"""Extract RTS labels from DARTS dataset.
|
"""Extract RTS labels from DARTS dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -91,7 +90,7 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
|
def extract_darts_mllabels(grid: Grid, level: int):
|
||||||
with stopwatch("Load data"):
|
with stopwatch("Load data"):
|
||||||
grid_gdf = grids.open(grid, level)
|
grid_gdf = grids.open(grid, level)
|
||||||
darts_mllabels = (
|
darts_mllabels = (
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,7 @@ from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_g
|
||||||
from entropice.spatial.xvec import to_xvec
|
from entropice.spatial.xvec import to_xvec
|
||||||
from entropice.utils import codecs
|
from entropice.utils import codecs
|
||||||
from entropice.utils.paths import FIGURES_DIR, get_era5_stores
|
from entropice.utils.paths import FIGURES_DIR, get_era5_stores
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
|
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -579,7 +580,7 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily
|
||||||
|
|
||||||
@cli.command
|
@cli.command
|
||||||
def viz(
|
def viz(
|
||||||
grid: Literal["hex", "healpix"] | None = None,
|
grid: Grid | None = None,
|
||||||
level: int | None = None,
|
level: int | None = None,
|
||||||
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
|
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
|
||||||
high_qual: bool = False,
|
high_qual: bool = False,
|
||||||
|
|
@ -587,7 +588,7 @@ def viz(
|
||||||
"""Visualize a small overview of ERA5 variables for a given aggregation.
|
"""Visualize a small overview of ERA5 variables for a given aggregation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"], optional): Grid type for spatial representation.
|
grid (Grid, optional): Grid type for spatial representation.
|
||||||
If provided along with level, the ERA5 data will be decoded onto the specified grid.
|
If provided along with level, the ERA5 data will be decoded onto the specified grid.
|
||||||
level (int, optional): Level of the grid for spatial representation.
|
level (int, optional): Level of the grid for spatial representation.
|
||||||
If provided along with grid, the ERA5 data will be decoded onto the specified grid.
|
If provided along with grid, the ERA5 data will be decoded onto the specified grid.
|
||||||
|
|
@ -653,7 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
||||||
|
|
||||||
@cli.command
|
@cli.command
|
||||||
def spatial_agg(
|
def spatial_agg(
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
concurrent_partitions: int = 20,
|
concurrent_partitions: int = 20,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import hashlib
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from functools import cached_property, lru_cache
|
from functools import cached_property
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
@ -32,6 +32,7 @@ from sklearn import set_config
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
|
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -41,6 +42,27 @@ set_config(array_api_dispatch=True)
|
||||||
sns.set_theme("talk", "whitegrid")
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
|
covcol: dict[TargetDataset, str] = {
|
||||||
|
"darts_rts": "darts_has_coverage",
|
||||||
|
"darts_mllabels": "dartsml_has_coverage",
|
||||||
|
}
|
||||||
|
|
||||||
|
taskcol: dict[Task, dict[TargetDataset, str]] = {
|
||||||
|
"binary": {
|
||||||
|
"darts_rts": "darts_has_rts",
|
||||||
|
"darts_mllabels": "dartsml_has_rts",
|
||||||
|
},
|
||||||
|
"count": {
|
||||||
|
"darts_rts": "darts_rts_count",
|
||||||
|
"darts_mllabels": "dartsml_rts_count",
|
||||||
|
},
|
||||||
|
"density": {
|
||||||
|
"darts_rts": "darts_rts_density",
|
||||||
|
"darts_mllabels": "dartsml_rts_density",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||||
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
|
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
|
||||||
if temporal == "yearly":
|
if temporal == "yearly":
|
||||||
|
|
@ -53,10 +75,6 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "
|
||||||
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
||||||
|
|
||||||
|
|
||||||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
type Task = Literal["binary", "count", "density"]
|
|
||||||
|
|
||||||
|
|
||||||
def bin_values(
|
def bin_values(
|
||||||
values: pd.Series,
|
values: pd.Series,
|
||||||
task: Literal["count", "density"],
|
task: Literal["count", "density"],
|
||||||
|
|
@ -144,7 +162,7 @@ class DatasetInputs:
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CategoricalTrainingDataset:
|
class CategoricalTrainingDataset:
|
||||||
dataset: pd.DataFrame
|
dataset: gpd.GeoDataFrame
|
||||||
X: DatasetInputs
|
X: DatasetInputs
|
||||||
y: DatasetLabels
|
y: DatasetLabels
|
||||||
z: pd.Series
|
z: pd.Series
|
||||||
|
|
@ -162,12 +180,12 @@ class DatasetStats(TypedDict):
|
||||||
|
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class DatasetEnsemble:
|
class DatasetEnsemble:
|
||||||
grid: Literal["hex", "healpix"]
|
grid: Grid
|
||||||
level: int
|
level: int
|
||||||
target: Literal["darts_rts", "darts_mllabels"]
|
target: Literal["darts_rts", "darts_mllabels"]
|
||||||
members: list[L2Dataset] = field(
|
members: list[L2SourceDataset] = field(
|
||||||
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
)
|
)
|
||||||
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
||||||
|
|
@ -186,19 +204,12 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def covcol(self) -> str:
|
def covcol(self) -> str:
|
||||||
return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
return covcol[self.target]
|
||||||
|
|
||||||
def taskcol(self, task: Task) -> str:
|
def taskcol(self, task: Task) -> str:
|
||||||
if task == "binary":
|
return taskcol[task][self.target]
|
||||||
return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts"
|
|
||||||
elif task == "count":
|
|
||||||
return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count"
|
|
||||||
elif task == "density":
|
|
||||||
return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid task: {task}")
|
|
||||||
|
|
||||||
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||||
if member == "AlphaEarth":
|
if member == "AlphaEarth":
|
||||||
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||||
|
|
@ -340,8 +351,9 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
print(f"=== Total number of features in dataset: {stats['total_features']}")
|
print(f"=== Total number of features in dataset: {stats['total_features']}")
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
def create(
|
||||||
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
|
||||||
|
) -> gpd.GeoDataFrame:
|
||||||
# n: no cache, o: overwrite cache, r: read cache if exists
|
# n: no cache, o: overwrite cache, r: read cache if exists
|
||||||
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
|
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
|
||||||
if cache_mode == "r" and cache_file.exists():
|
if cache_mode == "r" and cache_file.exists():
|
||||||
|
|
@ -425,26 +437,14 @@ class DatasetEnsemble:
|
||||||
print(f"Saved dataset to cache at {cache_file}.")
|
print(f"Saved dataset to cache at {cache_file}.")
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
def create_cat_training_dataset(
|
def _cat_and_split(
|
||||||
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||||
) -> CategoricalTrainingDataset:
|
) -> CategoricalTrainingDataset:
|
||||||
"""Create a categorical dataset for training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (Task): Task type.
|
|
||||||
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CategoricalTrainingDataset: The prepared categorical training dataset.
|
|
||||||
|
|
||||||
"""
|
|
||||||
covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
|
||||||
dataset = self.create(filter_target_col=covcol)
|
|
||||||
taskcol = self.taskcol(task)
|
taskcol = self.taskcol(task)
|
||||||
|
|
||||||
valid_labels = dataset[taskcol].notna()
|
valid_labels = dataset[taskcol].notna()
|
||||||
|
|
||||||
cols_to_drop = {"geometry", taskcol, covcol}
|
cols_to_drop = {"geometry", taskcol, self.covcol}
|
||||||
cols_to_drop |= {
|
cols_to_drop |= {
|
||||||
col
|
col
|
||||||
for col in dataset.columns
|
for col in dataset.columns
|
||||||
|
|
@ -505,3 +505,19 @@ class DatasetEnsemble:
|
||||||
z=model_labels,
|
z=model_labels,
|
||||||
split=split,
|
split=split,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_cat_training_dataset(
|
||||||
|
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||||
|
) -> CategoricalTrainingDataset:
|
||||||
|
"""Create a categorical dataset for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): Task type.
|
||||||
|
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CategoricalTrainingDataset: The prepared categorical training dataset.
|
||||||
|
|
||||||
|
"""
|
||||||
|
dataset = self.create(filter_target_col=self.covcol)
|
||||||
|
return self._cat_and_split(dataset, task, device)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -23,6 +22,7 @@ from xgboost.sklearn import XGBClassifier
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.ml.inference import predict_proba
|
from entropice.ml.inference import predict_proba
|
||||||
from entropice.utils.paths import get_cv_results_dir
|
from entropice.utils.paths import get_cv_results_dir
|
||||||
|
from entropice.utils.types import Model, Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -50,12 +50,20 @@ _metrics = {
|
||||||
|
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class CVSettings:
|
class CVSettings:
|
||||||
n_iter: int = 2000
|
n_iter: int = 2000
|
||||||
robust: bool = False
|
robust: bool = False
|
||||||
task: Literal["binary", "count", "density"] = "binary"
|
task: Task = "binary"
|
||||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
model: Model = "espa"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TrainingSettings(DatasetEnsemble, CVSettings):
|
||||||
|
param_grid: dict
|
||||||
|
cv_splits: int
|
||||||
|
metrics: list[str]
|
||||||
|
classes: list[str]
|
||||||
|
|
||||||
|
|
||||||
def _create_clf(
|
def _create_clf(
|
||||||
|
|
@ -141,7 +149,7 @@ def random_cv(
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||||
|
|
@ -202,18 +210,18 @@ def random_cv(
|
||||||
# Store the search settings
|
# Store the search settings
|
||||||
# First convert the param_grid distributions to a serializable format
|
# First convert the param_grid distributions to a serializable format
|
||||||
param_grid_serializable = _serialize_param_grid(param_grid)
|
param_grid_serializable = _serialize_param_grid(param_grid)
|
||||||
combined_settings = {
|
combined_settings = TrainingSettings(
|
||||||
**asdict(settings),
|
**asdict(settings),
|
||||||
**asdict(dataset_ensemble),
|
**asdict(dataset_ensemble),
|
||||||
"param_grid": param_grid_serializable,
|
param_grid=param_grid_serializable,
|
||||||
"cv_splits": cv.get_n_splits(),
|
cv_splits=cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
metrics=metrics,
|
||||||
"classes": training_data.y.labels,
|
classes=training_data.y.labels,
|
||||||
}
|
)
|
||||||
settings_file = results_dir / "search_settings.toml"
|
settings_file = results_dir / "search_settings.toml"
|
||||||
print(f"Storing search settings to {settings_file}")
|
print(f"Storing search settings to {settings_file}")
|
||||||
with open(settings_file, "w") as f:
|
with open(settings_file, "w") as f:
|
||||||
toml.dump({"settings": combined_settings}, f)
|
toml.dump({"settings": asdict(combined_settings)}, f)
|
||||||
|
|
||||||
# Store the best estimator model
|
# Store the best estimator model
|
||||||
best_model_file = results_dir / "best_estimator_model.pkl"
|
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from stopuhr import stopwatch
|
||||||
from xdggs.healpix import HealpixInfo
|
from xdggs.healpix import HealpixInfo
|
||||||
|
|
||||||
from entropice.spatial import grids
|
from entropice.spatial import grids
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -585,7 +586,7 @@ def aggregate_raster_into_grid(
|
||||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||||
aggregations: _Aggregations | Literal["interpolate"],
|
aggregations: _Aggregations | Literal["interpolate"],
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
n_partitions: int | None = 20,
|
n_partitions: int | None = 20,
|
||||||
concurrent_partitions: int = 5,
|
concurrent_partitions: int = 5,
|
||||||
|
|
@ -599,7 +600,7 @@ def aggregate_raster_into_grid(
|
||||||
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
||||||
No further partitioning will be done and the n_partitions argument will be ignored.
|
No further partitioning will be done and the n_partitions argument will be ignored.
|
||||||
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
||||||
grid (Literal["hex", "healpix"]): The type of grid to use.
|
grid (Grid): The type of grid to use.
|
||||||
level (int): The level of the grid.
|
level (int): The level of the grid.
|
||||||
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||||
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
|
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ Date: 09. June 2025
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cartopy.crs as ccrs
|
import cartopy.crs as ccrs
|
||||||
import cartopy.feature as cfeature
|
import cartopy.feature as cfeature
|
||||||
|
|
@ -26,16 +25,17 @@ from stopuhr import stopwatch
|
||||||
from xdggs.healpix import HealpixInfo
|
from xdggs.healpix import HealpixInfo
|
||||||
|
|
||||||
from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file
|
from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file
|
||||||
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
||||||
|
|
||||||
def open(grid: Literal["hex", "healpix"], level: int):
|
def open(grid: Grid, level: int):
|
||||||
"""Open a saved grid from parquet file.
|
"""Open a saved grid from parquet file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -47,11 +47,11 @@ def open(grid: Literal["hex", "healpix"], level: int):
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
|
|
||||||
def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
|
def get_cell_ids(grid: Grid, level: int):
|
||||||
"""Get the cell IDs of a saved grid.
|
"""Get the cell IDs of a saved grid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def cli(grid: Literal["hex", "healpix"], level: int):
|
def cli(grid: Grid, level: int):
|
||||||
"""CLI entry point."""
|
"""CLI entry point."""
|
||||||
print(f"Creating {grid} grid at level {level}...")
|
print(f"Creating {grid} grid at level {level}...")
|
||||||
if grid == "hex":
|
if grid == "hex":
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from entropice.utils.types import Grid, Task
|
||||||
|
|
||||||
DATA_DIR = (
|
DATA_DIR = (
|
||||||
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||||
)
|
)
|
||||||
|
|
@ -42,23 +44,23 @@ dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.par
|
||||||
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||||
|
|
||||||
|
|
||||||
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
def _get_gridname(grid: Grid, level: int) -> str:
|
||||||
return f"permafrost_{grid}{level}"
|
return f"permafrost_{grid}{level}"
|
||||||
|
|
||||||
|
|
||||||
def get_grid_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_grid_file(grid: Grid, level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
gridfile = GRIDS_DIR / f"{gridname}_grid.parquet"
|
gridfile = GRIDS_DIR / f"{gridname}_grid.parquet"
|
||||||
return gridfile
|
return gridfile
|
||||||
|
|
||||||
|
|
||||||
def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_grid_viz_file(grid: Grid, level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
vizfile = FIGURES_DIR / f"{gridname}_grid.png"
|
vizfile = FIGURES_DIR / f"{gridname}_grid.png"
|
||||||
return vizfile
|
return vizfile
|
||||||
|
|
||||||
|
|
||||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
|
def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
if labels:
|
if labels:
|
||||||
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
|
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
|
||||||
|
|
@ -67,13 +69,13 @@ def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool
|
||||||
return rtsfile
|
return rtsfile
|
||||||
|
|
||||||
|
|
||||||
def get_annual_embeddings_file(grid: Literal["hex", "healpix"], level: int, year: int) -> Path:
|
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
|
embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
|
||||||
return embfile
|
return embfile
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_embeddings_store(grid: Grid, level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
|
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
|
||||||
return embstore
|
return embstore
|
||||||
|
|
@ -81,9 +83,9 @@ def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
|
|
||||||
def get_era5_stores(
|
def get_era5_stores(
|
||||||
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
||||||
grid: Literal["hex", "healpix"] | None = None,
|
grid: Grid | None = None,
|
||||||
level: int | None = None,
|
level: int | None = None,
|
||||||
):
|
) -> Path:
|
||||||
if grid is None or level is None:
|
if grid is None or level is None:
|
||||||
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
|
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
|
||||||
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
|
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
|
||||||
|
|
@ -94,9 +96,9 @@ def get_era5_stores(
|
||||||
|
|
||||||
|
|
||||||
def get_arcticdem_stores(
|
def get_arcticdem_stores(
|
||||||
grid: Literal["hex", "healpix"] | None = None,
|
grid: Grid | None = None,
|
||||||
level: int | None = None,
|
level: int | None = None,
|
||||||
):
|
) -> Path:
|
||||||
if grid is None or level is None:
|
if grid is None or level is None:
|
||||||
return DATA_DIR / "arcticdem32m.icechunk.zarr"
|
return DATA_DIR / "arcticdem32m.icechunk.zarr"
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
|
|
@ -104,7 +106,7 @@ def get_arcticdem_stores(
|
||||||
return aligned_path
|
return aligned_path
|
||||||
|
|
||||||
|
|
||||||
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_train_dataset_file(grid: Grid, level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
||||||
return dataset_file
|
return dataset_file
|
||||||
|
|
@ -122,9 +124,9 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
|
||||||
|
|
||||||
def get_cv_results_dir(
|
def get_cv_results_dir(
|
||||||
name: str,
|
name: str,
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
task: Literal["binary", "count", "density"],
|
task: Task,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
|
|
||||||
11
src/entropice/utils/types.py
Normal file
11
src/entropice/utils/types.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""Shared types used across the entropice codebase."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
type Grid = Literal["hex", "healpix"]
|
||||||
|
type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
|
||||||
|
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
|
||||||
|
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||||
|
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||||
|
type Task = Literal["binary", "count", "density"]
|
||||||
|
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||||
Loading…
Add table
Add a link
Reference in a new issue