Unify the training scripts and add SHAP
This commit is contained in:
parent
073502c51d
commit
3ce6b6e867
38 changed files with 5876 additions and 681 deletions
14
autogluon-config.toml
Normal file
14
autogluon-config.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[tool.entropice-autogluon]
|
||||||
|
time-limit = 60
|
||||||
|
presets = "medium"
|
||||||
|
target = "darts_v1"
|
||||||
|
task = "density"
|
||||||
|
experiment = "tobi-tests"
|
||||||
|
|
||||||
|
grid = "hex"
|
||||||
|
level = 5
|
||||||
|
members = ["ERA5-shoulder", "ArcticDEM"]
|
||||||
|
|
||||||
|
[tool.entropice-autogluon.dimension-filters]
|
||||||
|
# ERA5-shoulder = { aggregations = "median" }
|
||||||
|
ArcticDEM = { aggregations = "median" }
|
||||||
19
hpsearchcv-config.toml
Normal file
19
hpsearchcv-config.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
[tool.entropice-hpsearchcv]
|
||||||
|
n-iter = 5
|
||||||
|
target = "darts_v1"
|
||||||
|
task = "binary"
|
||||||
|
splitter = "kfold"
|
||||||
|
# model = "xgboost"
|
||||||
|
# model = "rf" SHAP error
|
||||||
|
model = "espa"
|
||||||
|
experiment = "tobi-tests"
|
||||||
|
scaler = "standard" # They dont work because of Array API
|
||||||
|
normalize = true
|
||||||
|
|
||||||
|
grid = "hex"
|
||||||
|
level = 5
|
||||||
|
members = ["ERA5-shoulder", "ArcticDEM"]
|
||||||
|
|
||||||
|
[tool.entropice-hpsearchcv.dimension-filters]
|
||||||
|
# ERA5-shoulder = { aggregations = "median" }
|
||||||
|
ArcticDEM = { aggregations = "median" }
|
||||||
|
|
@ -67,8 +67,10 @@ dependencies = [
|
||||||
"ruff>=0.14.11,<0.15",
|
"ruff>=0.14.11,<0.15",
|
||||||
"pandas-stubs>=2.3.3.251201,<3",
|
"pandas-stubs>=2.3.3.251201,<3",
|
||||||
"pytest>=9.0.2,<10",
|
"pytest>=9.0.2,<10",
|
||||||
"autogluon-tabular[all,mitra]>=1.5.0",
|
"autogluon-tabular[all,mitra,realmlp,interpret,fastai,tabm,tabpfn,tabdpt,tabpfnmix,tabicl,skew,imodels]>=1.5.0",
|
||||||
"shap>=0.50.0,<0.51", "h5py>=3.15.1,<4",
|
"shap>=0.50.0,<0.51",
|
||||||
|
"h5py>=3.15.1,<4",
|
||||||
|
"pydantic>=2.12.5,<3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -77,8 +79,8 @@ darts = "entropice.ingest.darts:cli"
|
||||||
alpha-earth = "entropice.ingest.alphaearth:main"
|
alpha-earth = "entropice.ingest.alphaearth:main"
|
||||||
era5 = "entropice.ingest.era5:cli"
|
era5 = "entropice.ingest.era5:cli"
|
||||||
arcticdem = "entropice.ingest.arcticdem:cli"
|
arcticdem = "entropice.ingest.arcticdem:cli"
|
||||||
train = "entropice.ml.training:cli"
|
train = "entropice.ml.hpsearchcv:cli"
|
||||||
autogluon = "entropice.ml.autogluon_training:cli"
|
autogluon = "entropice.ml.autogluon:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Pages:
|
||||||
- Overview: List of available result directories with some summary statistics.
|
- Overview: List of available result directories with some summary statistics.
|
||||||
- Training Data: Visualization of training data distributions.
|
- Training Data: Visualization of training data distributions.
|
||||||
- Training Results Analysis: Analysis of training results and model performance.
|
- Training Results Analysis: Analysis of training results and model performance.
|
||||||
|
- Experiment Analysis: Compare multiple training runs within an experiment.
|
||||||
- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
|
- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
|
||||||
- Model State: Visualization of model state and features.
|
- Model State: Visualization of model state and features.
|
||||||
- Inference: Visualization of inference results.
|
- Inference: Visualization of inference results.
|
||||||
|
|
@ -13,6 +14,7 @@ Pages:
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.views.dataset_page import render_dataset_page
|
from entropice.dashboard.views.dataset_page import render_dataset_page
|
||||||
|
from entropice.dashboard.views.experiment_analysis_page import render_experiment_analysis_page
|
||||||
from entropice.dashboard.views.inference_page import render_inference_page
|
from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.views.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
|
|
@ -27,6 +29,7 @@ def main():
|
||||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
||||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
|
experiment_analysis_page = st.Page(render_experiment_analysis_page, title="Experiment Analysis", icon="🔬")
|
||||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||||
|
|
||||||
|
|
@ -34,7 +37,7 @@ def main():
|
||||||
{
|
{
|
||||||
"Overview": [overview_page],
|
"Overview": [overview_page],
|
||||||
"Data": [data_page],
|
"Data": [data_page],
|
||||||
"Experiments": [training_analysis_page, model_state_page],
|
"Experiments": [training_analysis_page, experiment_analysis_page, model_state_page],
|
||||||
"Inference": [inference_page],
|
"Inference": [inference_page],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
0
src/entropice/dashboard/plots/__init__.py
Normal file
0
src/entropice/dashboard/plots/__init__.py
Normal file
619
src/entropice/dashboard/plots/correlations.py
Normal file
619
src/entropice/dashboard/plots/correlations.py
Normal file
|
|
@ -0,0 +1,619 @@
|
||||||
|
"""Plots for cross-dataset correlation and similarity analysis."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import seaborn as sns
|
||||||
|
from plotly.subplots import make_subplots
|
||||||
|
from scipy.cluster.hierarchy import dendrogram, linkage
|
||||||
|
from scipy.spatial.distance import squareform
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
|
||||||
|
def select_top_variable_features(
|
||||||
|
data_dict: dict[str, pd.Series],
|
||||||
|
n_features: int = 500,
|
||||||
|
method: Literal["variance", "iqr", "cv"] = "variance",
|
||||||
|
) -> dict[str, pd.Series]:
|
||||||
|
"""Select the top N most variable features from a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
n_features: Number of features to select
|
||||||
|
method: Method to measure variability ('variance', 'iqr', or 'cv')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered dictionary with only the most variable features
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) <= n_features:
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
# Calculate variability metric for each feature
|
||||||
|
variability_scores = {}
|
||||||
|
for var_name, values in data_dict.items():
|
||||||
|
clean_values = values.dropna()
|
||||||
|
if len(clean_values) == 0:
|
||||||
|
variability_scores[var_name] = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
if method == "variance":
|
||||||
|
variability_scores[var_name] = clean_values.var()
|
||||||
|
elif method == "iqr":
|
||||||
|
variability_scores[var_name] = clean_values.quantile(0.75) - clean_values.quantile(0.25)
|
||||||
|
elif method == "cv":
|
||||||
|
mean_val = clean_values.mean()
|
||||||
|
if abs(mean_val) > 1e-10:
|
||||||
|
variability_scores[var_name] = clean_values.std() / abs(mean_val)
|
||||||
|
else:
|
||||||
|
variability_scores[var_name] = 0
|
||||||
|
|
||||||
|
# Sort by variability and select top N
|
||||||
|
sorted_vars = sorted(variability_scores.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
top_vars = [var_name for var_name, _ in sorted_vars[:n_features]]
|
||||||
|
|
||||||
|
return {k: v for k, v in data_dict.items() if k in top_vars}
|
||||||
|
|
||||||
|
|
||||||
|
def create_matplotlib_correlation_heatmap(
|
||||||
|
data_dict: dict[str, pd.Series],
|
||||||
|
method: Literal["pearson", "kendall", "spearman"] = "pearson",
|
||||||
|
cluster: bool = False,
|
||||||
|
max_labels: int = 100,
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""Create a correlation heatmap using matplotlib for large datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
method: Correlation method ('pearson', 'spearman', or 'kendall')
|
||||||
|
cluster: Whether to reorder variables by hierarchical clustering
|
||||||
|
max_labels: Maximum number of labels to show on axes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Matplotlib Figure with correlation heatmap
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) < 2:
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 6))
|
||||||
|
ax.text(0.5, 0.5, "Need at least 2 variables for correlation analysis", ha="center", va="center", fontsize=12)
|
||||||
|
ax.axis("off")
|
||||||
|
return fig
|
||||||
|
|
||||||
|
# Create DataFrame from all variables
|
||||||
|
df = pd.DataFrame(data_dict)
|
||||||
|
|
||||||
|
# Drop rows with any NaN to get valid correlations
|
||||||
|
df_clean = df.dropna()
|
||||||
|
|
||||||
|
if len(df_clean) == 0:
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 6))
|
||||||
|
ax.text(0.5, 0.5, "No overlapping non-null data between variables", ha="center", va="center", fontsize=12)
|
||||||
|
ax.axis("off")
|
||||||
|
return fig
|
||||||
|
|
||||||
|
# Subsample if too many rows for performance
|
||||||
|
if len(df_clean) > 50000:
|
||||||
|
df_clean = df_clean.sample(n=50000, random_state=42)
|
||||||
|
|
||||||
|
# Calculate correlation matrix
|
||||||
|
corr = df_clean.corr(method=method)
|
||||||
|
|
||||||
|
# Apply hierarchical clustering if requested
|
||||||
|
if cluster and len(corr) > 2:
|
||||||
|
# Use correlation distance for clustering
|
||||||
|
corr_dist = 1 - np.abs(corr.values)
|
||||||
|
np.fill_diagonal(corr_dist, 0)
|
||||||
|
condensed_dist = squareform(corr_dist, checks=False)
|
||||||
|
linkage_matrix = linkage(condensed_dist, method="average")
|
||||||
|
dendro = dendrogram(linkage_matrix, no_plot=True)
|
||||||
|
cluster_order = dendro["leaves"]
|
||||||
|
|
||||||
|
# Reorder correlation matrix
|
||||||
|
corr = corr.iloc[cluster_order, cluster_order]
|
||||||
|
|
||||||
|
# Determine figure size based on number of variables
|
||||||
|
n_vars = len(corr)
|
||||||
|
fig_size = max(10, min(n_vars * 0.3, 50))
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig, ax = plt.subplots(figsize=(fig_size, fig_size))
|
||||||
|
|
||||||
|
# Create heatmap using seaborn
|
||||||
|
sns.heatmap(
|
||||||
|
corr,
|
||||||
|
cmap="RdBu_r",
|
||||||
|
center=0,
|
||||||
|
vmin=-1,
|
||||||
|
vmax=1,
|
||||||
|
square=True,
|
||||||
|
linewidths=0.5 if n_vars < 50 else 0,
|
||||||
|
cbar_kws={"shrink": 0.8, "label": f"{method.title()} Correlation"},
|
||||||
|
ax=ax,
|
||||||
|
xticklabels=n_vars <= max_labels,
|
||||||
|
yticklabels=n_vars <= max_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set title
|
||||||
|
title = f"Correlation Matrix ({method.title()}, {n_vars} variables)"
|
||||||
|
if cluster:
|
||||||
|
title += " - Hierarchically Clustered"
|
||||||
|
ax.set_title(title, fontsize=14, pad=20)
|
||||||
|
|
||||||
|
# Rotate labels if shown
|
||||||
|
if n_vars <= max_labels:
|
||||||
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
|
||||||
|
plt.setp(ax.get_yticklabels(), rotation=0, fontsize=8)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_full_correlation_heatmap(
|
||||||
|
data_dict: dict[str, pd.Series],
|
||||||
|
method: Literal["pearson", "kendall", "spearman"] = "pearson",
|
||||||
|
cluster: bool = False,
|
||||||
|
) -> go.Figure:
|
||||||
|
"""Create a correlation heatmap for all variables across all datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
method: Correlation method ('pearson', 'spearman', or 'kendall')
|
||||||
|
cluster: Whether to reorder variables by hierarchical clustering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with correlation heatmap
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) < 2:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Need at least 2 variables for correlation analysis",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create DataFrame from all variables
|
||||||
|
df = pd.DataFrame(data_dict)
|
||||||
|
|
||||||
|
# Drop rows with any NaN to get valid correlations
|
||||||
|
df_clean = df.dropna()
|
||||||
|
|
||||||
|
if len(df_clean) == 0:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="No overlapping non-null data between variables",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subsample if too many rows for performance
|
||||||
|
if len(df_clean) > 50000:
|
||||||
|
df_clean = df_clean.sample(n=50000, random_state=42)
|
||||||
|
|
||||||
|
# Calculate correlation matrix
|
||||||
|
corr = df_clean.corr(method=method)
|
||||||
|
|
||||||
|
# Apply hierarchical clustering if requested
|
||||||
|
if cluster and len(corr) > 2:
|
||||||
|
# Use correlation distance for clustering
|
||||||
|
corr_dist = 1 - np.abs(corr.values)
|
||||||
|
np.fill_diagonal(corr_dist, 0) # Ensure diagonal is 0
|
||||||
|
condensed_dist = squareform(corr_dist, checks=False)
|
||||||
|
linkage_matrix = linkage(condensed_dist, method="average")
|
||||||
|
dendro = dendrogram(linkage_matrix, no_plot=True)
|
||||||
|
cluster_order = dendro["leaves"]
|
||||||
|
|
||||||
|
# Reorder correlation matrix
|
||||||
|
corr = corr.iloc[cluster_order, cluster_order]
|
||||||
|
|
||||||
|
# Create heatmap
|
||||||
|
fig = go.Figure(
|
||||||
|
data=go.Heatmap(
|
||||||
|
z=corr.values,
|
||||||
|
x=corr.columns,
|
||||||
|
y=corr.index,
|
||||||
|
colorscale="RdBu_r",
|
||||||
|
zmid=0,
|
||||||
|
zmin=-1,
|
||||||
|
zmax=1,
|
||||||
|
text=np.round(corr.values, 2),
|
||||||
|
texttemplate="%{text}",
|
||||||
|
textfont={"size": 8},
|
||||||
|
colorbar={"title": f"{method.title()}<br>Correlation"},
|
||||||
|
hovertemplate=f"<b>%{{y}}</b> vs <b>%{{x}}</b><br>{method.title()} correlation: %{{z:.3f}}<extra></extra>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
title = f"Cross-Dataset Variable Correlations ({method.title()})"
|
||||||
|
if cluster:
|
||||||
|
title += " - Hierarchically Clustered"
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=title,
|
||||||
|
height=max(600, len(corr) * 15),
|
||||||
|
width=max(800, len(corr) * 15),
|
||||||
|
xaxis={"side": "bottom", "tickangle": 45},
|
||||||
|
yaxis={"tickangle": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_pca_biplot(
|
||||||
|
data_dict: dict[str, pd.Series],
|
||||||
|
n_components: int = 2,
|
||||||
|
show_loadings: bool = True,
|
||||||
|
) -> go.Figure:
|
||||||
|
"""Create a PCA biplot showing feature relationships in principal component space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
n_components: Number of principal components to compute
|
||||||
|
show_loadings: Whether to show loading vectors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with PCA biplot
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) < 2:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Need at least 2 variables for PCA analysis",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create DataFrame and clean
|
||||||
|
df = pd.DataFrame(data_dict).dropna()
|
||||||
|
|
||||||
|
if len(df) < 10:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Not enough overlapping data for PCA analysis",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subsample if needed
|
||||||
|
if len(df) > 50000:
|
||||||
|
df = df.sample(n=50000, random_state=42)
|
||||||
|
|
||||||
|
# Standardize features
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaled_data = scaler.fit_transform(df)
|
||||||
|
|
||||||
|
# Perform PCA
|
||||||
|
n_components = min(n_components, len(data_dict), len(df))
|
||||||
|
pca = PCA(n_components=n_components)
|
||||||
|
pca_result = pca.fit_transform(scaled_data)
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
if n_components >= 2:
|
||||||
|
# Scatter plot of observations (subsampled for visibility)
|
||||||
|
sample_size = min(5000, len(pca_result))
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
indices = rng.choice(len(pca_result), sample_size, replace=False)
|
||||||
|
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=pca_result[indices, 0],
|
||||||
|
y=pca_result[indices, 1],
|
||||||
|
mode="markers",
|
||||||
|
marker={"size": 3, "opacity": 0.3, "color": "lightblue"},
|
||||||
|
name="Observations",
|
||||||
|
showlegend=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add loading vectors if requested
|
||||||
|
if show_loadings:
|
||||||
|
loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
|
||||||
|
|
||||||
|
# Scale loadings for visibility
|
||||||
|
max_loading = np.abs(loadings[:, :2]).max()
|
||||||
|
max_data = max(np.abs(pca_result[:, :2]).max(axis=0))
|
||||||
|
scale = max_data / max_loading * 0.8
|
||||||
|
|
||||||
|
for i, var_name in enumerate(df.columns):
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=[0, loadings[i, 0] * scale],
|
||||||
|
y=[0, loadings[i, 1] * scale],
|
||||||
|
mode="lines+text",
|
||||||
|
line={"color": "red", "width": 2},
|
||||||
|
text=["", var_name],
|
||||||
|
textposition="top center",
|
||||||
|
textfont={"size": 10, "color": "darkred"},
|
||||||
|
name=var_name,
|
||||||
|
showlegend=False,
|
||||||
|
hovertemplate=f"<b>{var_name}</b><br>"
|
||||||
|
+ f"PC1 loading: {loadings[i, 0]:.3f}<br>"
|
||||||
|
+ f"PC2 loading: {loadings[i, 1]:.3f}<extra></extra>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
var_exp = pca.explained_variance_ratio_
|
||||||
|
fig.update_xaxes(title=f"PC1 ({var_exp[0] * 100:.1f}% variance)")
|
||||||
|
fig.update_yaxes(title=f"PC2 ({var_exp[1] * 100:.1f}% variance)")
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title="PCA Biplot - Feature Relationships in Principal Component Space",
|
||||||
|
height=700,
|
||||||
|
showlegend=show_loadings,
|
||||||
|
hovermode="closest",
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_dendrogram_plot(
|
||||||
|
data_dict: dict[str, pd.Series],
|
||||||
|
method: Literal["pearson", "kendall", "spearman"] = "pearson",
|
||||||
|
) -> go.Figure:
|
||||||
|
"""Create a dendrogram showing hierarchical clustering of variables based on correlation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
method: Correlation method to use for distance calculation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with dendrogram
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) < 2:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Need at least 2 variables for clustering",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create DataFrame and clean
|
||||||
|
df = pd.DataFrame(data_dict).dropna()
|
||||||
|
|
||||||
|
if len(df) < 10:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Not enough overlapping data for clustering",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subsample if needed
|
||||||
|
if len(df) > 50000:
|
||||||
|
df = df.sample(n=50000, random_state=42)
|
||||||
|
|
||||||
|
# Calculate correlation and convert to distance
|
||||||
|
corr = df.corr(method=method)
|
||||||
|
corr_dist = 1 - np.abs(corr.values)
|
||||||
|
np.fill_diagonal(corr_dist, 0)
|
||||||
|
|
||||||
|
# Perform hierarchical clustering
|
||||||
|
condensed_dist = squareform(corr_dist, checks=False)
|
||||||
|
linkage_matrix = linkage(condensed_dist, method="average")
|
||||||
|
|
||||||
|
# Create dendrogram
|
||||||
|
dendro = dendrogram(linkage_matrix, labels=list(df.columns), no_plot=True)
|
||||||
|
|
||||||
|
# Extract dendrogram data
|
||||||
|
icoord = np.array(dendro["icoord"])
|
||||||
|
dcoord = np.array(dendro["dcoord"])
|
||||||
|
|
||||||
|
# Create plotly figure
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Add dendrogram lines
|
||||||
|
for i in range(len(icoord)):
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=icoord[i],
|
||||||
|
y=dcoord[i],
|
||||||
|
mode="lines",
|
||||||
|
line={"color": "black", "width": 1.5},
|
||||||
|
hoverinfo="skip",
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add labels
|
||||||
|
labels = dendro["ivl"]
|
||||||
|
label_pos = np.arange(5, len(labels) * 10 + 5, 10)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=f"Hierarchical Clustering of Variables (based on {method.title()} correlation)",
|
||||||
|
xaxis={
|
||||||
|
"title": "Variables",
|
||||||
|
"tickmode": "array",
|
||||||
|
"tickvals": label_pos,
|
||||||
|
"ticktext": labels,
|
||||||
|
"tickangle": 45,
|
||||||
|
},
|
||||||
|
yaxis={"title": "Distance (1 - |correlation|)"},
|
||||||
|
height=600,
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_mutual_information_matrix(data_dict: dict[str, pd.Series]) -> go.Figure:
|
||||||
|
"""Create a mutual information matrix to capture non-linear relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with mutual information heatmap
|
||||||
|
|
||||||
|
"""
|
||||||
|
from sklearn.feature_selection import mutual_info_regression
|
||||||
|
|
||||||
|
if len(data_dict) < 2:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Need at least 2 variables for mutual information analysis",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create DataFrame and clean
|
||||||
|
df = pd.DataFrame(data_dict).dropna()
|
||||||
|
|
||||||
|
if len(df) < 100:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="Not enough overlapping data for mutual information analysis",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subsample if needed
|
||||||
|
if len(df) > 10000: # MI is computationally expensive
|
||||||
|
df = df.sample(n=10000, random_state=42)
|
||||||
|
|
||||||
|
# Calculate pairwise mutual information
|
||||||
|
n_vars = len(df.columns)
|
||||||
|
mi_matrix = np.zeros((n_vars, n_vars))
|
||||||
|
|
||||||
|
for i, col_i in enumerate(df.columns):
|
||||||
|
for j, col_j in enumerate(df.columns):
|
||||||
|
if i == j:
|
||||||
|
mi_matrix[i, j] = 1.0 # Self-information (normalized)
|
||||||
|
elif i < j:
|
||||||
|
# Calculate MI
|
||||||
|
X = df[col_i].to_numpy().reshape(-1, 1) # noqa: N806
|
||||||
|
y = df[col_j].to_numpy()
|
||||||
|
mi = mutual_info_regression(X, y, random_state=42)[0]
|
||||||
|
|
||||||
|
# Normalize by entropy (approximate)
|
||||||
|
mi_norm = mi / (np.std(y) * np.std(X.flatten()) + 1e-10)
|
||||||
|
mi_matrix[i, j] = mi_norm
|
||||||
|
mi_matrix[j, i] = mi_norm
|
||||||
|
|
||||||
|
# Create heatmap
|
||||||
|
fig = go.Figure(
|
||||||
|
data=go.Heatmap(
|
||||||
|
z=mi_matrix,
|
||||||
|
x=list(df.columns),
|
||||||
|
y=list(df.columns),
|
||||||
|
colorscale="YlOrRd",
|
||||||
|
zmin=0,
|
||||||
|
text=np.round(mi_matrix, 3),
|
||||||
|
texttemplate="%{text}",
|
||||||
|
textfont={"size": 8},
|
||||||
|
colorbar={"title": "Normalized<br>Mutual Info"},
|
||||||
|
hovertemplate="<b>%{y}</b> vs <b>%{x}</b><br>Mutual Information: %{z:.3f}<extra></extra>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title="Mutual Information Matrix (captures non-linear relationships)",
|
||||||
|
height=max(600, n_vars * 15),
|
||||||
|
width=max(800, n_vars * 15),
|
||||||
|
xaxis={"side": "bottom", "tickangle": 45},
|
||||||
|
yaxis={"tickangle": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_variance_plot(data_dict: dict[str, pd.Series]) -> go.Figure:
|
||||||
|
"""Create a bar plot showing variance/spread of each feature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with variance bar plot
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(data_dict) == 0:
|
||||||
|
return go.Figure().add_annotation(
|
||||||
|
text="No variables provided",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate statistics for each variable
|
||||||
|
stats_data = []
|
||||||
|
for var_name, values in data_dict.items():
|
||||||
|
clean_values = values.dropna()
|
||||||
|
if len(clean_values) > 0:
|
||||||
|
stats_data.append(
|
||||||
|
{
|
||||||
|
"Variable": var_name,
|
||||||
|
"Std Dev": clean_values.std(),
|
||||||
|
"Variance": clean_values.var(),
|
||||||
|
"IQR": clean_values.quantile(0.75) - clean_values.quantile(0.25),
|
||||||
|
"Range": clean_values.max() - clean_values.min(),
|
||||||
|
"CV": clean_values.std() / (abs(clean_values.mean()) + 1e-10),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
stats_df = pd.DataFrame(stats_data).sort_values("Variance", ascending=False)
|
||||||
|
|
||||||
|
# Create subplots for different variance metrics
|
||||||
|
fig = make_subplots(
|
||||||
|
rows=2,
|
||||||
|
cols=2,
|
||||||
|
subplot_titles=["Standard Deviation", "Interquartile Range (IQR)", "Range", "Coefficient of Variation"],
|
||||||
|
vertical_spacing=0.12,
|
||||||
|
horizontal_spacing=0.10,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = [
|
||||||
|
("Std Dev", 1, 1),
|
||||||
|
("IQR", 1, 2),
|
||||||
|
("Range", 2, 1),
|
||||||
|
("CV", 2, 2),
|
||||||
|
]
|
||||||
|
|
||||||
|
for metric, row, col in metrics:
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
x=stats_df["Variable"],
|
||||||
|
y=stats_df[metric],
|
||||||
|
name=metric,
|
||||||
|
marker={"color": stats_df[metric], "colorscale": "Viridis"},
|
||||||
|
showlegend=False,
|
||||||
|
hovertemplate=f"<b>%{{x}}</b><br>{metric}: %{{y:.3f}}<extra></extra>",
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_xaxes(tickangle=45, row=row, col=col)
|
||||||
|
fig.update_yaxes(title_text=metric, row=row, col=col)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title="Feature Variance and Spread Metrics",
|
||||||
|
height=800,
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
972
src/entropice/dashboard/plots/experiment_comparison.py
Normal file
972
src/entropice/dashboard/plots/experiment_comparison.py
Normal file
|
|
@ -0,0 +1,972 @@
|
||||||
|
"""Plots for experiment comparison and analysis."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.express as px
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
|
from entropice.dashboard.utils.formatters import format_metric_name
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import geopandas as gpd
|
||||||
|
import pydeck as pdk
|
||||||
|
|
||||||
|
|
||||||
|
def create_grid_level_comparison_plot(
    results_df: pd.DataFrame,
    metric: str,
    split: str = "test",
) -> go.Figure:
    """Create a plot comparing model performance across grid levels.

    Args:
        results_df: DataFrame with experiment results including grid, level, model, and metrics
        metric: Metric to compare (e.g., 'f1', 'accuracy', 'r2')
        split: Data split to show ('train', 'test', or 'combined')

    Returns:
        Plotly figure showing performance across grid levels

    Raises:
        ValueError: If the requested metric column is missing from results_df.

    """
    metric_col = f"{split}_{metric}"

    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    # Work on a copy so the caller's DataFrame is never mutated.
    results_df = results_df.copy()
    results_df["grid_level"] = results_df["grid"] + "_" + results_df["level"].astype(str)

    # Grid levels ordered by spatial resolution (coarse -> fine).
    grid_level_order = [
        "hex_3",
        "healpix_6",
        "healpix_7",
        "hex_4",
        "healpix_8",
        "hex_5",
        "healpix_9",
        "healpix_10",
        "hex_6",
    ]
    # Hoisted out of the plotting loops: rank of each grid level in the order above.
    grid_level_rank = {gl: i for i, gl in enumerate(grid_level_order)}

    # Human-readable x-axis labels for each grid level.
    grid_level_display = {
        "hex_3": "Hex-3",
        "healpix_6": "Healpix-6",
        "healpix_7": "Healpix-7",
        "hex_4": "Hex-4",
        "healpix_8": "Healpix-8",
        "hex_5": "Hex-5",
        "healpix_9": "Healpix-9",
        "healpix_10": "Healpix-10",
        "hex_6": "Hex-6",
    }
    results_df["grid_level_display"] = results_df["grid_level"].map(grid_level_display)

    # Color per target dataset (2nd color of a 5-color palette for more saturation).
    unique_targets = results_df["target"].unique()
    target_colors = {target: get_palette(target, 5)[1] for target in unique_targets}

    # Marker symbol per model type; unknown models fall back to "circle".
    model_symbols = {
        "espa": "circle",
        "xgboost": "square",
        "rf": "diamond",
        "knn": "cross",
        "ensemble": "star",
        "autogluon": "star",
    }

    # Upper-cased model name, kept for hover information.
    results_df["model_display"] = results_df["model"].str.upper()

    # Box plot without individual points; points are added separately with symbols.
    fig = px.box(
        results_df,
        x="grid_level_display",
        y=metric_col,
        color="target",
        facet_col="task",
        points=False,  # We'll add points separately with symbols
        title=f"{format_metric_name(metric)} by Grid Level ({split.capitalize()} Set)",
        labels={
            metric_col: format_metric_name(metric),
            "grid_level_display": "Grid Level",
            "target": "Target Dataset",
        },
        color_discrete_map=target_colors,
        category_orders={"grid_level_display": [grid_level_display[gl] for gl in grid_level_order]},
    )

    # One facet column per task (single implicit facet when no task column exists).
    if "task" in results_df.columns:
        unique_tasks = results_df["task"].unique()
    else:
        unique_tasks = [None]

    # Overlay line + scatter traces for each target/model combination.
    for target in unique_targets:
        target_data = results_df[results_df["target"] == target]
        for model in target_data["model"].unique():
            model_data = target_data[target_data["model"] == model]
            symbol = model_symbols.get(model, "circle")

            for task_idx, task in enumerate(unique_tasks):
                if task is not None:
                    task_data = model_data[model_data["task"] == task]
                else:
                    task_data = model_data

                if len(task_data) == 0:
                    continue

                # Copy before adding the helper column: task_data is a filtered
                # slice, and assigning to it directly triggers pandas'
                # SettingWithCopyWarning (and may silently fail to propagate
                # under copy-on-write semantics).
                task_data = task_data.copy()
                task_data["grid_level_order"] = task_data["grid_level"].map(grid_level_rank)
                task_data = task_data.sort_values("grid_level_order")

                # Target subplot for this facet.
                row = 1
                col = task_idx + 1 if task is not None else 1

                # Dotted line connecting points of the same model across grid levels.
                line = go.Scatter(
                    x=task_data["grid_level_display"],
                    y=task_data[metric_col],
                    mode="lines",
                    line={
                        "color": target_colors[target],
                        "width": 1.5,
                        "dash": "dot",
                    },
                    showlegend=False,
                    hoverinfo="skip",
                    xaxis=f"x{col}" if col > 1 else "x",
                    yaxis=f"y{col}" if col > 1 else "y",
                )
                fig.add_trace(line, row=row, col=col)

                # Markers drawn on top of the connecting line.
                scatter = go.Scatter(
                    x=task_data["grid_level_display"],
                    y=task_data[metric_col],
                    mode="markers",
                    marker={
                        "symbol": symbol,
                        "size": 8,
                        "color": target_colors[target],
                        "line": {"width": 1, "color": "white"},
                    },
                    name=f"{target} ({model.upper()})",
                    legendgroup=target,
                    showlegend=(task_idx == 0),  # Only show in legend once
                    hovertemplate=(
                        f"<b>{target}</b><br>"
                        f"Model: {model.upper()}<br>"
                        f"Grid Level: %{{x}}<br>"
                        f"{format_metric_name(metric)}: %{{y:.4f}}<br>"
                        "<extra></extra>"
                    ),
                    xaxis=f"x{col}" if col > 1 else "x",
                    yaxis=f"y{col}" if col > 1 else "y",
                )
                fig.add_trace(scatter, row=row, col=col)

    fig.update_layout(
        height=500,
        showlegend=True,
        hovermode="closest",
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_ranking_plot(
    results_df: pd.DataFrame,
    metric: str,
    split: str = "test",
    top_n: int = 10,
) -> go.Figure:
    """Rank model configurations by a metric and plot the best ones.

    Args:
        results_df: DataFrame with experiment results
        metric: Metric to rank by
        split: Data split to show
        top_n: Number of top models to show

    Returns:
        Plotly figure showing top models

    """
    metric_col = f"{split}_{metric}"
    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    # Best-scoring rows; copy so the label column does not leak into the input.
    ranked = results_df.nlargest(top_n, metric_col).copy()

    # Label each row as "<model> (<grid>_<level>, <task>)".
    ranked["model_label"] = [
        f"{model} ({grid}_{level}, {task})"
        for model, grid, level, task in zip(
            ranked["model"].astype(str),
            ranked["grid"],
            ranked["level"].astype(str),
            ranked["task"],
        )
    ]

    bar_colors = get_palette("models", len(ranked))

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=ranked[metric_col],
            y=ranked["model_label"],
            orientation="h",
            marker={"color": bar_colors},
            text=ranked[metric_col].round(4),
            textposition="auto",
        )
    )

    fig.update_layout(
        title=f"Top {top_n} Models by {format_metric_name(metric)} ({split.capitalize()} Set)",
        xaxis_title=format_metric_name(metric),
        yaxis_title="Model Configuration",
        height=max(400, top_n * 40),
        showlegend=False,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_consistency_heatmap(
    results_df: pd.DataFrame,
    metric: str,
    split: str = "test",
) -> go.Figure:
    """Show how consistently each model performs across task/target combinations.

    Args:
        results_df: DataFrame with experiment results
        metric: Metric to analyze
        split: Data split to show

    Returns:
        Plotly figure showing heatmap of model performance

    """
    metric_col = f"{split}_{metric}"
    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    # Combine task and target into a single column axis for the heatmap.
    df = results_df.copy()
    df["task_target"] = df["task"] + "_" + df["target"]

    # Best score for every (model, task_target) cell.
    best_scores = df.groupby(["model", "task_target"])[metric_col].max().reset_index()
    matrix = best_scores.pivot_table(index="model", columns="task_target", values=metric_col)

    cell_values = matrix.to_numpy()
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=cell_values,
            x=matrix.columns,
            y=matrix.index,
            colorscale="Viridis",
            text=cell_values.round(3),
            texttemplate="%{text}",
            textfont={"size": 10},
            colorbar={"title": format_metric_name(metric)},
        )
    )

    fig.update_layout(
        title=f"Model Consistency: {format_metric_name(metric)} Across Tasks and Targets ({split.capitalize()} Set)",
        xaxis_title="Task_Target",
        yaxis_title="Model",
        height=max(400, len(matrix.index) * 50),
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_importance_comparison_plot(
    feature_importance_df: pd.DataFrame,
    top_n: int = 20,
) -> go.Figure:
    """Compare the importance of the leading features across configurations.

    Args:
        feature_importance_df: DataFrame with columns: feature, importance, model, grid, level, task
        top_n: Number of top features to show

    Returns:
        Plotly figure showing feature importance comparison

    """
    # Rank features by their average importance over every configuration.
    mean_importance = feature_importance_df.groupby("feature")["importance"].mean().reset_index()
    leaders = mean_importance.nlargest(top_n, "importance")["feature"].tolist()

    # Keep only the leading features; copy before adding the label column.
    subset = feature_importance_df[feature_importance_df["feature"].isin(leaders)].copy()

    # Configuration label: "<model> (<grid>_<level>, <task>)".
    subset["config"] = [
        f"{model} ({grid}_{level}, {task})"
        for model, grid, level, task in zip(
            subset["model"].astype(str),
            subset["grid"],
            subset["level"].astype(str),
            subset["task"],
        )
    ]

    fig = px.bar(
        subset,
        x="feature",
        y="importance",
        color="config",
        barmode="group",
        title=f"Top {top_n} Features: Importance Comparison Across Configurations",
        labels={"importance": "Feature Importance", "feature": "Feature", "config": "Configuration"},
    )

    fig.update_layout(
        height=600,
        xaxis_tickangle=-45,
        showlegend=True,
        legend={"orientation": "v", "yanchor": "top", "y": 1, "xanchor": "left", "x": 1.02},
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_importance_heatmap(
    feature_importance_df: pd.DataFrame,
    top_n: int = 30,
) -> go.Figure:
    """Render the leading features' importances as a feature-by-config heatmap.

    Args:
        feature_importance_df: DataFrame with columns: feature, importance, model, grid, level, task
        top_n: Number of top features to show

    Returns:
        Plotly figure showing heatmap

    """
    # Rank features by their average importance over every configuration.
    mean_importance = feature_importance_df.groupby("feature")["importance"].mean().reset_index()
    leaders = mean_importance.nlargest(top_n, "importance")["feature"].tolist()

    # Keep only the leading features; copy before adding the config column.
    subset = feature_importance_df[feature_importance_df["feature"].isin(leaders)].copy()

    # Configuration key: "<grid>_<level>_<task>_<model>".
    subset["config"] = (
        subset["grid"]
        + "_"
        + subset["level"].astype(str)
        + "_"
        + subset["task"]
        + "_"
        + subset["model"].astype(str)
    )

    # Features as rows, configurations as columns; missing cells become 0.
    matrix = subset.pivot_table(index="feature", columns="config", values="importance", aggfunc="mean", fill_value=0)

    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=matrix.to_numpy(),
            x=matrix.columns,
            y=matrix.index,
            colorscale="Viridis",
            colorbar={"title": "Importance"},
        )
    )

    fig.update_layout(
        title=f"Top {top_n} Features: Importance Heatmap Across Configurations",
        xaxis_title="Configuration",
        yaxis_title="Feature",
        height=max(500, len(matrix.index) * 20),
        xaxis_tickangle=-45,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_performance_distribution(
    results_df: pd.DataFrame,
    metric: str,
    split: str = "test",
) -> go.Figure:
    """Show how a metric is distributed for each model type as violin plots.

    Args:
        results_df: DataFrame with experiment results
        metric: Metric to analyze
        split: Data split to show

    Returns:
        Plotly figure with violin plots

    """
    metric_col = f"{split}_{metric}"
    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    # One palette color per distinct model, in first-seen order.
    model_names = results_df["model"].unique()
    color_map = dict(zip(model_names, get_palette("models", len(model_names))))

    fig = px.violin(
        results_df,
        x="model",
        y=metric_col,
        color="model",
        box=True,
        points="all",
        title=f"{format_metric_name(metric)} Distribution by Model Type ({split.capitalize()} Set)",
        labels={metric_col: format_metric_name(metric), "model": "Model"},
        color_discrete_map=color_map,
    )

    fig.update_layout(
        height=500,
        showlegend=False,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_grid_vs_model_performance(
    results_df: pd.DataFrame,
    metric: str,
    split: str = "test",
) -> go.Figure:
    """Facet model performance by grid level (columns) and task (rows).

    Args:
        results_df: DataFrame with experiment results
        metric: Metric to analyze
        split: Data split to show

    Returns:
        Plotly figure with faceted plots

    """
    metric_col = f"{split}_{metric}"
    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    # Copy before deriving the grid_level facet key.
    plot_df = results_df.copy()
    plot_df["grid_level"] = plot_df["grid"] + "_" + plot_df["level"].astype(str)

    fig = px.box(
        plot_df,
        x="model",
        y=metric_col,
        color="model",
        facet_col="grid_level",
        facet_row="task",
        points="all",
        title=f"{format_metric_name(metric)}: Model Performance Across Grid Levels and Tasks ({split.capitalize()})",
        labels={metric_col: format_metric_name(metric), "model": "Model"},
    )

    fig.update_layout(
        height=800,
        showlegend=True,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_performance_improvement_plot(
    results_df: pd.DataFrame,
    metric: str,
    baseline_model: str = "rf",
    split: str = "test",
) -> go.Figure:
    """Create a plot showing performance improvement relative to a baseline model.

    Args:
        results_df: DataFrame with experiment results
        metric: Metric to analyze
        baseline_model: Model to use as baseline for comparison
        split: Data split to show

    Returns:
        Plotly figure showing improvement over baseline

    Raises:
        ValueError: If the requested metric column is missing from results_df.

    """
    metric_col = f"{split}_{metric}"

    if metric_col not in results_df.columns:
        raise ValueError(f"Metric {metric_col} not found in results")

    results_df = results_df.copy()

    # Best baseline score per task-target-grid-level combination.
    baseline_perf = (
        results_df[results_df["model"] == baseline_model]
        .groupby(["task", "target", "grid", "level"])[metric_col]
        .max()
        .reset_index()
        .rename(columns={metric_col: "baseline_score"})
    )

    # Left-join so configurations without a baseline get NaN (and drop out of the plot).
    results_with_baseline = results_df.merge(baseline_perf, on=["task", "target", "grid", "level"], how="left")

    # Absolute and relative improvement over the baseline. A zero baseline
    # score would yield +/-inf percentages and wreck the axis scale, so mask
    # it to NaN before dividing.
    results_with_baseline["improvement"] = results_with_baseline[metric_col] - results_with_baseline["baseline_score"]
    safe_baseline = results_with_baseline["baseline_score"].where(results_with_baseline["baseline_score"] != 0)
    results_with_baseline["improvement_pct"] = 100 * results_with_baseline["improvement"] / safe_baseline

    # The baseline model is not compared against itself.
    results_with_baseline = results_with_baseline[results_with_baseline["model"] != baseline_model]

    fig = px.box(
        results_with_baseline,
        x="model",
        y="improvement_pct",
        color="model",
        points="all",
        title=(
            f"Performance Improvement Over {baseline_model.upper()} Baseline "
            f"({format_metric_name(metric)}, {split.capitalize()} Set)"
        ),
        labels={"improvement_pct": "Improvement (%)", "model": "Model"},
    )

    # Dashed reference line at 0% (baseline parity).
    fig.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Baseline")

    fig.update_layout(
        height=500,
        showlegend=False,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_top_models_bar_chart(
    task_df: pd.DataFrame,
    metric: str,
    task_name: str,
    split: str = "test",
    top_n: int = 10,
) -> go.Figure:
    """Horizontal bar chart of the best model configurations for one task.

    Args:
        task_df: DataFrame filtered for a specific task
        metric: Metric to display
        task_name: Name of the task for the title
        split: Data split to show
        top_n: Number of top models to show

    Returns:
        Plotly figure with horizontal bar chart

    """
    metric_col = f"{split}_{metric}"
    if metric_col not in task_df.columns:
        raise ValueError(f"Metric {metric_col} not found in data")

    best = task_df.nlargest(top_n, metric_col).copy()

    # Bar label: "<MODEL> (<grid>-<level>, <target>)".
    best["label"] = (
        best["model"].str.upper()
        + " ("
        + best["grid"]
        + "-"
        + best["level"].astype(str)
        + ", "
        + best["target"]
        + ")"
    )

    # Ascending order puts the best configuration at the top of the chart.
    best = best.sort_values(metric_col, ascending=True)

    # Consistent color per model type.
    model_names = best["model"].unique()
    palette = get_palette("models", len(model_names))
    best["color"] = best["model"].map(dict(zip(model_names, palette)))

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            y=best["label"],
            x=best[metric_col],
            orientation="h",
            marker={"color": best["color"]},
            text=best[metric_col].round(4),
            textposition="auto",
            hovertemplate=("<b>%{y}</b><br>" + f"{format_metric_name(metric)}: %{{x:.4f}}<br>" + "<extra></extra>"),
        )
    )

    fig.update_layout(
        title=f"Top {top_n} Models - {task_name.replace('_', ' ').title()}",
        xaxis_title=format_metric_name(metric),
        yaxis_title="",
        height=max(400, top_n * 40),
        showlegend=False,
        margin={"l": 250},
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_importance_by_grid_level(
    fi_df: pd.DataFrame,
    top_n: int = 15,
) -> go.Figure:
    """Grouped bars of the leading features' mean importance per grid level.

    Args:
        fi_df: Feature importance DataFrame with columns: feature, importance, grid_level
        top_n: Number of top features to show

    Returns:
        Plotly figure with grouped bar chart

    """
    # Leading features by mean importance over the whole frame.
    leaders = fi_df.groupby("feature")["importance"].mean().nlargest(top_n).index.tolist()
    subset = fi_df[fi_df["feature"].isin(leaders)]

    # Average importance per (grid level, feature) pair.
    averaged = subset.groupby(["grid_level", "feature"])["importance"].mean().reset_index()

    fig = px.bar(
        averaged,
        x="feature",
        y="importance",
        color="grid_level",
        barmode="group",
        title=f"Top {top_n} Features by Grid Level",
        labels={"importance": "Mean Importance", "feature": "Feature", "grid_level": "Grid Level"},
    )

    fig.update_layout(
        height=500,
        xaxis_tickangle=-45,
        showlegend=True,
        legend={"orientation": "v", "yanchor": "top", "y": 1, "xanchor": "left", "x": 1.02},
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_consistency_plot(
    fi_df: pd.DataFrame,
    top_n: int = 15,
) -> go.Figure:
    """Create a scatter plot showing feature importance vs consistency.

    Args:
        fi_df: Feature importance DataFrame
        top_n: Number of top features to show

    Returns:
        Plotly figure with scatter plot

    """
    # Top features by mean importance across all configurations.
    overall_top = fi_df.groupby("feature")["importance"].mean().nlargest(top_n).index.tolist()

    # Mean/std of importance per feature.
    stats = (
        fi_df[fi_df["feature"].isin(overall_top)].groupby("feature")["importance"].agg(["mean", "std"]).reset_index()
    )
    # A feature observed only once has std == NaN; treat it as perfectly
    # consistent so px.scatter's size channel does not break on missing values.
    stats["std"] = stats["std"].fillna(0.0)
    stats["cv"] = stats["std"] / stats["mean"]

    # Color by data source when available.
    if "data_source" in fi_df.columns:
        # Deduplicate on the feature column (not on the pair): a feature listed
        # under two different sources would otherwise leave duplicate index
        # labels and make .map() raise. The first source wins.
        data_source_map = (
            fi_df[["feature", "data_source"]].drop_duplicates(subset="feature").set_index("feature")["data_source"]
        )
        stats["data_source"] = stats["feature"].map(data_source_map)
        color_col = "data_source"
    else:
        color_col = None

    fig = px.scatter(
        stats,
        x="cv",
        y="mean",
        size="std",
        color=color_col,
        hover_data=["feature"],
        text="feature",
        title="Feature Importance vs Consistency (Coefficient of Variation)",
        labels={
            "cv": "Coefficient of Variation (lower = more consistent)",
            "mean": "Mean Importance",
            "std": "Std Dev",
        },
    )

    fig.update_traces(textposition="top center", textfont_size=8)

    fig.update_layout(
        height=600,
        showlegend=True,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_source_importance_bars(
    fi_df: pd.DataFrame,
) -> go.Figure:
    """Summarize total feature importance contributed by each data source.

    Args:
        fi_df: Feature importance DataFrame with data_source column

    Returns:
        Plotly figure with bar chart

    """
    if "data_source" not in fi_df.columns:
        raise ValueError("data_source column not found in feature importance data")

    # Sum/mean/count of importance per source, largest contributors first.
    per_source = (
        fi_df.groupby("data_source")["importance"]
        .agg(["sum", "mean", "count"])
        .reset_index()
        .sort_values("sum", ascending=False)
    )

    bar_colors = get_palette("sources", len(per_source))

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=per_source["data_source"],
            y=per_source["sum"],
            marker={"color": bar_colors},
            text=per_source["sum"].round(2),
            textposition="auto",
            hovertemplate=(
                "<b>%{x}</b><br>"
                "Total Importance: %{y:.2f}<br>"
                "Mean: %{customdata[0]:.4f}<br>"
                "Features: %{customdata[1]}<br>"
                "<extra></extra>"
            ),
            customdata=per_source[["mean", "count"]].values,
        )
    )

    fig.update_layout(
        title="Feature Importance by Data Source",
        xaxis_title="Data Source",
        yaxis_title="Total Importance",
        height=400,
        showlegend=False,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_inference_maps(
    inference_gdf: "gpd.GeoDataFrame",
    grid: str,
    level: int,
    task: str,
) -> tuple["pdk.Deck", "pdk.Deck"]:
    """Create inference maps showing mean and std of predictions using pydeck.

    Args:
        inference_gdf: GeoDataFrame with geometry, mean_prediction, std_prediction columns
        grid: Grid type (e.g., 'hex', 'healpix')
        level: Grid resolution level
        task: Task type for title

    Returns:
        Tuple of (mean_deck, std_deck) pydeck visualizations

    """
    import matplotlib.colors as mcolors
    import numpy as np
    import pydeck as pdk
    from matplotlib.colors import LinearSegmentedColormap

    from entropice.dashboard.utils.colors import hex_to_rgb
    from entropice.dashboard.utils.geometry import fix_hex_geometry

    # Convert to EPSG:4326 (what pydeck expects) and fix antimeridian issues
    # for hex cells. Copy so the caller's GeoDataFrame is never mutated.
    gdf = inference_gdf.copy().to_crs("EPSG:4326")
    gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)

    def _fill_colors(values, hex_stops, cmap_name):
        """Map raw values to RGB fill colors, normalized to the 2-98 percentile range."""
        cmap = LinearSegmentedColormap.from_list(cmap_name, hex_stops)
        vmin, vmax = np.nanpercentile(values, [2, 98])
        if vmax > vmin:
            normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
        else:
            # Degenerate range (constant values): fall back to the colormap's low end.
            normalized = np.zeros_like(values)
        return [hex_to_rgb(mcolors.to_hex(cmap(val))) for val in normalized]

    def _to_geojson(color_col, value_col):
        """Build a GeoJSON feature list carrying the fill color and the raw value."""
        return [
            {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": {
                    "fill_color": row[color_col],
                    value_col: float(row[value_col]),
                },
            }
            for _, row in gdf.iterrows()
        ]

    def _make_deck(geojson, tooltip_html, tooltip_bg):
        """Assemble a Deck with one GeoJson layer over the Arctic initial view."""
        layer = pdk.Layer(
            "GeoJsonLayer",
            geojson,
            opacity=0.85,
            stroked=True,
            filled=True,
            extruded=False,
            wireframe=False,
            get_fill_color="properties.fill_color",
            get_line_color=[100, 100, 100],
            line_width_min_pixels=0.3,
            pickable=True,
        )
        view_state = pdk.ViewState(
            latitude=75,
            longitude=0,
            zoom=2.5,
            pitch=0,
        )
        return pdk.Deck(
            layers=[layer],
            initial_view_state=view_state,
            tooltip={
                "html": tooltip_html,
                "style": {"backgroundColor": tooltip_bg, "color": "white"},
            },
            map_style="https://basemaps.cartocdn.com/gl/positron-gl-style/style.json",
        )

    # White -> blue ramp for predictions, white -> green for uncertainty.
    colors_mean = ["#f7fbff", "#deebf7", "#c6dbef", "#9ecae1", "#6baed6", "#4292c6", "#2171b5", "#08519c", "#08306b"]
    colors_std = ["#f7fcf5", "#e5f5e0", "#c7e9c0", "#a1d99b", "#74c476", "#41ab5d", "#238b45", "#006d2c", "#00441b"]

    gdf["mean_color"] = _fill_colors(gdf["mean_prediction"].to_numpy(), colors_mean, "prediction")
    gdf["std_color"] = _fill_colors(gdf["std_prediction"].to_numpy(), colors_std, "uncertainty")

    deck_mean = _make_deck(
        _to_geojson("mean_color", "mean_prediction"),
        (
            f"<b>Mean Prediction:</b> {{mean_prediction}}<br>"
            f"<b>{grid.upper()}-{level}</b> | <b>Task:</b> {task.title()}"
        ),
        "#2171b5",
    )
    deck_std = _make_deck(
        _to_geojson("std_color", "std_prediction"),
        (
            f"<b>Uncertainty (Std):</b> {{std_prediction}}<br>"
            f"<b>{grid.upper()}-{level}</b> | <b>Task:</b> {task.title()}"
        ),
        "#238b45",
    )

    return deck_mean, deck_std
|
||||||
|
|
@ -6,10 +6,74 @@ import plotly.graph_objects as go
|
||||||
import pydeck as pdk
|
import pydeck as pdk
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.utils.class_ordering import get_ordered_classes, sort_class_series
|
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||||
from entropice.dashboard.utils.loaders import TrainingResult
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
|
from entropice.utils.types import Task
|
||||||
|
|
||||||
|
# Canonical orderings imported from the ML pipeline
|
||||||
|
# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"}
|
||||||
|
# Count/Density labels are defined in the bin_values function
|
||||||
|
BINARY_LABELS = ["No RTS", "RTS"]
|
||||||
|
COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]
|
||||||
|
DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"]
|
||||||
|
|
||||||
|
CLASS_ORDERINGS: dict[Task | str, list[str]] = {
|
||||||
|
"binary": BINARY_LABELS,
|
||||||
|
"count_regimes": COUNT_LABELS,
|
||||||
|
"density_regimes": DENSITY_LABELS,
|
||||||
|
# Legacy aliases (deprecated)
|
||||||
|
"count": COUNT_LABELS,
|
||||||
|
"density": DENSITY_LABELS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]:
|
||||||
|
"""Get properly ordered class labels for a given task.
|
||||||
|
|
||||||
|
This uses the same canonical ordering as defined in the ML dataset module,
|
||||||
|
ensuring consistency between training and inference visualizations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
available_classes: Optional list of available classes to filter and order.
|
||||||
|
If None, returns all canonical classes for the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of class labels in proper order.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> get_ordered_classes("binary")
|
||||||
|
['No RTS', 'RTS']
|
||||||
|
>>> get_ordered_classes("count_regimes", ["None", "Few", "Several"])
|
||||||
|
['None', 'Few', 'Several']
|
||||||
|
|
||||||
|
"""
|
||||||
|
canonical_order = CLASS_ORDERINGS[task]
|
||||||
|
|
||||||
|
if available_classes is None:
|
||||||
|
return canonical_order
|
||||||
|
|
||||||
|
# Filter canonical order to only include available classes, preserving order
|
||||||
|
return [cls for cls in canonical_order if cls in available_classes]
|
||||||
|
|
||||||
|
|
||||||
|
def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series:
|
||||||
|
"""Sort a pandas Series with class labels according to canonical ordering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
series: Pandas Series with class labels as index.
|
||||||
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sorted Series with classes in canonical order.
|
||||||
|
|
||||||
|
"""
|
||||||
|
available_classes = series.index.tolist()
|
||||||
|
ordered_classes = get_ordered_classes(task, available_classes)
|
||||||
|
|
||||||
|
# Reindex to get proper order
|
||||||
|
return series.reindex(ordered_classes)
|
||||||
|
|
||||||
|
|
||||||
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
|
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
|
||||||
|
|
@ -60,7 +124,7 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictions_gdf: GeoDataFrame with predictions.
|
predictions_gdf: GeoDataFrame with predictions.
|
||||||
task: Task type ('binary', 'count', 'density').
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("📊 Predicted Class Distribution")
|
st.subheader("📊 Predicted Class Distribution")
|
||||||
|
|
@ -348,7 +412,7 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictions_gdf: GeoDataFrame with predictions.
|
predictions_gdf: GeoDataFrame with predictions.
|
||||||
task: Task type ('binary', 'count', 'density').
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("🔍 Class Comparison")
|
st.subheader("🔍 Class Comparison")
|
||||||
|
|
|
||||||
0
src/entropice/dashboard/sections/__init__.py
Normal file
0
src/entropice/dashboard/sections/__init__.py
Normal file
674
src/entropice/dashboard/sections/correlations.py
Normal file
674
src/entropice/dashboard/sections/correlations.py
Normal file
|
|
@ -0,0 +1,674 @@
|
||||||
|
"""Cross-dataset correlation analysis dashboard section."""
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.correlations import (
|
||||||
|
create_dendrogram_plot,
|
||||||
|
create_feature_variance_plot,
|
||||||
|
create_full_correlation_heatmap,
|
||||||
|
create_matplotlib_correlation_heatmap,
|
||||||
|
create_mutual_information_matrix,
|
||||||
|
create_pca_biplot,
|
||||||
|
select_top_variable_features,
|
||||||
|
)
|
||||||
|
from entropice.utils.types import L2SourceDataset
|
||||||
|
|
||||||
|
|
||||||
|
def _get_aggregation_dimensions(ds: xr.Dataset) -> list[str]:
|
||||||
|
"""Get aggregation dimension names from a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds: Xarray Dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of aggregation dimension names
|
||||||
|
|
||||||
|
"""
|
||||||
|
return [str(dim) for dim in ds.dims if dim in ("agg", "aggregations")]
|
||||||
|
|
||||||
|
|
||||||
|
def _set_all_aggregations_corr(
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset],
|
||||||
|
members_with_aggs: list[L2SourceDataset],
|
||||||
|
member_agg_dims: dict[L2SourceDataset, list[str]],
|
||||||
|
selected: bool,
|
||||||
|
):
|
||||||
|
"""Set all aggregation checkboxes to selected or deselected state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
members_with_aggs: List of members that have aggregations
|
||||||
|
member_agg_dims: Dictionary mapping members to aggregation dimensions
|
||||||
|
selected: True to select all, False to deselect all
|
||||||
|
|
||||||
|
"""
|
||||||
|
for member in members_with_aggs:
|
||||||
|
ds = member_datasets.get(member)
|
||||||
|
if ds is None:
|
||||||
|
continue
|
||||||
|
agg_dims = member_agg_dims.get(member, [])
|
||||||
|
for agg_dim in agg_dims:
|
||||||
|
if agg_dim in ds.dims:
|
||||||
|
agg_values = ds.coords[agg_dim].to_numpy().tolist()
|
||||||
|
for val in agg_values:
|
||||||
|
st.session_state[f"corr_agg_{member}_{agg_dim}_{val}"] = selected
|
||||||
|
|
||||||
|
|
||||||
|
def _set_median_only_aggregations_corr(
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset],
|
||||||
|
members_with_aggs: list[L2SourceDataset],
|
||||||
|
member_agg_dims: dict[L2SourceDataset, list[str]],
|
||||||
|
):
|
||||||
|
"""Set only median aggregations to selected, deselect all others.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
members_with_aggs: List of members that have aggregations
|
||||||
|
member_agg_dims: Dictionary mapping members to aggregation dimensions
|
||||||
|
|
||||||
|
"""
|
||||||
|
for member in members_with_aggs:
|
||||||
|
ds = member_datasets.get(member)
|
||||||
|
if ds is None:
|
||||||
|
continue
|
||||||
|
agg_dims = member_agg_dims.get(member, [])
|
||||||
|
for agg_dim in agg_dims:
|
||||||
|
if agg_dim in ds.dims:
|
||||||
|
agg_values = ds.coords[agg_dim].to_numpy().tolist()
|
||||||
|
for val in agg_values:
|
||||||
|
# Select only if value is or contains 'median'
|
||||||
|
is_median = str(val).lower() == "median" or "median" in str(val).lower()
|
||||||
|
st.session_state[f"corr_agg_{member}_{agg_dim}_{val}"] = is_median
|
||||||
|
|
||||||
|
|
||||||
|
def _render_member_aggregation_checkboxes(
|
||||||
|
member: L2SourceDataset,
|
||||||
|
ds: xr.Dataset,
|
||||||
|
agg_dims: list[str],
|
||||||
|
column: Any,
|
||||||
|
dimension_filters: dict[str, dict[str, list[str]]],
|
||||||
|
) -> dict[str, dict[str, list[str]]]:
|
||||||
|
"""Render checkboxes for a single member's aggregations."""
|
||||||
|
with column:
|
||||||
|
st.markdown(f"**{member}:**")
|
||||||
|
|
||||||
|
for agg_dim in agg_dims:
|
||||||
|
if agg_dim not in ds.dims:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get coordinate values
|
||||||
|
agg_values = [str(v) for v in ds.coords[agg_dim].to_numpy().tolist()]
|
||||||
|
|
||||||
|
st.markdown(f"*{agg_dim}:*")
|
||||||
|
selected_vals = []
|
||||||
|
|
||||||
|
for val in agg_values:
|
||||||
|
key = f"corr_agg_{member}_{agg_dim}_{val}"
|
||||||
|
|
||||||
|
if st.checkbox(val, value=True, key=key, help=f"Include {val} from {agg_dim}"):
|
||||||
|
selected_vals.append(val)
|
||||||
|
|
||||||
|
# Store selected values if not all selected
|
||||||
|
if selected_vals and len(selected_vals) < len(agg_values):
|
||||||
|
if member not in dimension_filters:
|
||||||
|
dimension_filters[member] = {}
|
||||||
|
dimension_filters[member][agg_dim] = selected_vals
|
||||||
|
|
||||||
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def _render_aggregation_selection(
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset],
|
||||||
|
) -> dict[str, dict[str, list[str]]]:
|
||||||
|
"""Render aggregation selection controls for members that have aggregations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping member names to their xarray Datasets
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping member names to dimension filters
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Find members with aggregations
|
||||||
|
members_with_aggs = []
|
||||||
|
member_agg_dims = {}
|
||||||
|
|
||||||
|
for member, ds in member_datasets.items():
|
||||||
|
agg_dims = _get_aggregation_dimensions(ds)
|
||||||
|
if agg_dims:
|
||||||
|
members_with_aggs.append(member)
|
||||||
|
member_agg_dims[member] = agg_dims
|
||||||
|
|
||||||
|
if not members_with_aggs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
st.markdown("#### Select Aggregations")
|
||||||
|
st.markdown(
|
||||||
|
"Select which spatial aggregations to include. "
|
||||||
|
"This allows you to analyze correlations for specific aggregation types (e.g., mean, median, std)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add buttons for common selections (outside form to allow state manipulation)
|
||||||
|
col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
|
||||||
|
|
||||||
|
with col_btn1:
|
||||||
|
if st.button("✅ Select All", use_container_width=True, key="corr_agg_select_all"):
|
||||||
|
_set_all_aggregations_corr(member_datasets, members_with_aggs, member_agg_dims, selected=True)
|
||||||
|
with col_btn2:
|
||||||
|
if st.button("📊 Median Only", use_container_width=True, key="corr_agg_median_only"):
|
||||||
|
_set_median_only_aggregations_corr(member_datasets, members_with_aggs, member_agg_dims)
|
||||||
|
with col_btn3:
|
||||||
|
if st.button("❌ Deselect All", use_container_width=True, key="corr_agg_deselect_all"):
|
||||||
|
_set_all_aggregations_corr(member_datasets, members_with_aggs, member_agg_dims, selected=False)
|
||||||
|
|
||||||
|
# Render form with checkboxes
|
||||||
|
with st.form("correlation_aggregation_selection_form"):
|
||||||
|
dimension_filters: dict[str, dict[str, list[str]]] = {}
|
||||||
|
|
||||||
|
# Create columns for each member
|
||||||
|
member_cols = st.columns(len(members_with_aggs))
|
||||||
|
|
||||||
|
for col_idx, member in enumerate(members_with_aggs):
|
||||||
|
dimension_filters = _render_member_aggregation_checkboxes(
|
||||||
|
member,
|
||||||
|
member_datasets[member],
|
||||||
|
member_agg_dims[member],
|
||||||
|
member_cols[col_idx],
|
||||||
|
dimension_filters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Submit button for the form
|
||||||
|
submitted = st.form_submit_button("Apply Aggregation Filters", type="primary", use_container_width=True)
|
||||||
|
|
||||||
|
if not submitted:
|
||||||
|
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_dataset_to_series(
|
||||||
|
ds: xr.Dataset,
|
||||||
|
prefix: str = "",
|
||||||
|
dimension_filter: dict[str, list[str]] | None = None,
|
||||||
|
) -> dict[str, pd.Series]:
|
||||||
|
"""Flatten an xarray Dataset into a dict of pandas Series.
|
||||||
|
|
||||||
|
Handles multi-dimensional variables by creating separate series for each combination
|
||||||
|
of non-cell_ids dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds: Xarray Dataset to flatten
|
||||||
|
prefix: Prefix to add to variable names
|
||||||
|
dimension_filter: Optional dict mapping dimension names to lists of allowed values.
|
||||||
|
Only combinations matching these values will be included.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping variable names to pandas Series with cell_ids as index
|
||||||
|
|
||||||
|
"""
|
||||||
|
series_dict = {}
|
||||||
|
dimension_filter = dimension_filter or {}
|
||||||
|
|
||||||
|
for var_name in ds.data_vars:
|
||||||
|
var_data = ds[var_name]
|
||||||
|
|
||||||
|
# Get dimensions other than cell_ids
|
||||||
|
other_dims = [dim for dim in var_data.dims if dim != "cell_ids"]
|
||||||
|
|
||||||
|
if len(other_dims) == 0:
|
||||||
|
# Simple 1D variable
|
||||||
|
series = var_data.to_series()
|
||||||
|
full_name = f"{prefix}{var_name}" if prefix else var_name
|
||||||
|
series_dict[full_name] = series
|
||||||
|
else:
|
||||||
|
# Multi-dimensional variable - create series for each combination
|
||||||
|
for coord_values in var_data.stack(stacked=other_dims).coords["stacked"].to_numpy(): # noqa: PD013
|
||||||
|
# Create selector dict
|
||||||
|
if isinstance(coord_values, tuple):
|
||||||
|
selector = dict(zip(other_dims, coord_values))
|
||||||
|
else:
|
||||||
|
selector = {other_dims[0]: coord_values}
|
||||||
|
|
||||||
|
# Check if this combination passes the dimension filter
|
||||||
|
if dimension_filter:
|
||||||
|
skip = False
|
||||||
|
for dim, val in selector.items():
|
||||||
|
if dim in dimension_filter:
|
||||||
|
# Check if value is in allowed list (convert to string for comparison)
|
||||||
|
if str(val) not in dimension_filter[dim]:
|
||||||
|
skip = True
|
||||||
|
break
|
||||||
|
if skip:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract 1D series
|
||||||
|
series = var_data.sel(selector).to_series()
|
||||||
|
|
||||||
|
# Create descriptive name
|
||||||
|
suffix = "_".join(str(v) for v in (coord_values if isinstance(coord_values, tuple) else [coord_values]))
|
||||||
|
full_name = f"{prefix}{var_name}_{suffix}" if prefix else f"{var_name}_{suffix}"
|
||||||
|
series_dict[full_name] = series
|
||||||
|
|
||||||
|
return series_dict
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def _render_correlation_matrix(data_dict: dict[str, pd.Series]):
|
||||||
|
"""Render correlation matrix visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("Correlation Matrix")
|
||||||
|
|
||||||
|
n_features = len(data_dict)
|
||||||
|
|
||||||
|
# Add feature reduction options for large datasets
|
||||||
|
reduced_data = data_dict
|
||||||
|
if n_features > 500:
|
||||||
|
st.warning(
|
||||||
|
f"⚠️ **Large dataset detected:** {n_features} features may be too many to visualize effectively. "
|
||||||
|
"Consider reducing the feature set using the options below."
|
||||||
|
)
|
||||||
|
|
||||||
|
with st.expander("🎯 Feature Reduction Options", expanded=True):
|
||||||
|
reduction_method = st.radio(
|
||||||
|
"Reduction Strategy",
|
||||||
|
options=["top_variable", "none"],
|
||||||
|
format_func=lambda x: (
|
||||||
|
"Select Most Variable Features" if x == "top_variable" else "Use All Features (may be slow)"
|
||||||
|
),
|
||||||
|
index=0,
|
||||||
|
key="corr_reduction_method",
|
||||||
|
)
|
||||||
|
|
||||||
|
if reduction_method == "top_variable":
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
n_top = st.slider(
|
||||||
|
"Number of Features to Keep",
|
||||||
|
min_value=50,
|
||||||
|
max_value=min(1000, n_features),
|
||||||
|
value=min(500, n_features),
|
||||||
|
step=50,
|
||||||
|
key="corr_n_top",
|
||||||
|
)
|
||||||
|
with col2:
|
||||||
|
variability_metric = st.selectbox(
|
||||||
|
"Variability Metric",
|
||||||
|
options=["variance", "iqr", "cv"],
|
||||||
|
format_func=lambda x: {
|
||||||
|
"variance": "Variance",
|
||||||
|
"iqr": "Interquartile Range",
|
||||||
|
"cv": "Coefficient of Variation",
|
||||||
|
}[x],
|
||||||
|
key="corr_variability_metric",
|
||||||
|
)
|
||||||
|
|
||||||
|
reduced_data = select_top_variable_features(data_dict, n_features=n_top, method=variability_metric)
|
||||||
|
st.success(f"✅ Reduced from {n_features} to {len(reduced_data)} most variable features")
|
||||||
|
|
||||||
|
n_reduced = len(reduced_data)
|
||||||
|
|
||||||
|
# Configuration options
|
||||||
|
col1, col2 = st.columns([3, 1])
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
method = st.selectbox(
|
||||||
|
"Correlation Method",
|
||||||
|
options=["pearson", "spearman", "kendall"],
|
||||||
|
index=0,
|
||||||
|
format_func=lambda x: x.title(),
|
||||||
|
help="Pearson: linear relationships | Spearman: monotonic relationships | Kendall: robust to outliers",
|
||||||
|
key="corr_method",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
cluster = cast(
|
||||||
|
bool,
|
||||||
|
st.toggle(
|
||||||
|
"Cluster Variables",
|
||||||
|
value=n_reduced < 100,
|
||||||
|
help="Reorder by hierarchical clustering",
|
||||||
|
key="corr_cluster",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if subsampling will occur
|
||||||
|
df_test = pd.DataFrame(reduced_data).dropna()
|
||||||
|
if len(df_test) > 50000:
|
||||||
|
st.info(
|
||||||
|
f"📊 **Dataset subsampled:** Using 50,000 randomly selected cells out of {len(df_test):,} "
|
||||||
|
"for performance. Correlations remain representative."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use matplotlib for large feature sets (> 200), plotly for smaller
|
||||||
|
use_matplotlib = n_reduced > 200
|
||||||
|
|
||||||
|
if use_matplotlib:
|
||||||
|
st.info(
|
||||||
|
f"📈 **Using Matplotlib rendering** for {n_reduced} features "
|
||||||
|
"(Plotly would exceed Streamlit's message size limit). "
|
||||||
|
"This produces a static image but handles large datasets efficiently."
|
||||||
|
)
|
||||||
|
fig = create_matplotlib_correlation_heatmap(reduced_data, method=method, cluster=cluster)
|
||||||
|
st.pyplot(fig, use_container_width=True)
|
||||||
|
# Close figure to free memory
|
||||||
|
import matplotlib.pyplot as plt_cleanup
|
||||||
|
|
||||||
|
plt_cleanup.close(fig)
|
||||||
|
else:
|
||||||
|
fig = create_full_correlation_heatmap(reduced_data, method=method, cluster=cluster)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
f"""
|
||||||
|
**{method.title()} correlation** measures the {"linear" if method == "pearson" else "monotonic"}
|
||||||
|
relationship between variables. Values range from -1 (perfect negative correlation) to
|
||||||
|
+1 (perfect positive correlation), with 0 indicating no correlation.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def _render_pca_analysis(data_dict: dict[str, pd.Series]):
|
||||||
|
"""Render PCA biplot visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("Principal Component Analysis (PCA)")
|
||||||
|
|
||||||
|
n_features = len(data_dict)
|
||||||
|
|
||||||
|
# Reduce features if too many for visualization
|
||||||
|
reduced_data = data_dict
|
||||||
|
if n_features > 500:
|
||||||
|
st.warning(
|
||||||
|
f"⚠️ **Large dataset:** {n_features} features detected. "
|
||||||
|
"Automatically selecting the 500 most variable features for PCA visualization."
|
||||||
|
)
|
||||||
|
reduced_data = select_top_variable_features(data_dict, n_features=500, method="variance")
|
||||||
|
|
||||||
|
show_loadings = cast(
|
||||||
|
bool,
|
||||||
|
st.toggle(
|
||||||
|
"Show Loading Vectors",
|
||||||
|
value=True,
|
||||||
|
help="Display how each variable contributes to the principal components",
|
||||||
|
key="pca_loadings",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = create_pca_biplot(reduced_data, n_components=2, show_loadings=show_loadings)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**PCA** reduces the dimensionality of the data while preserving variance.
|
||||||
|
- **Blue points**: Individual grid cells projected onto the first two principal components
|
||||||
|
- **Red arrows**: Loading vectors showing how each variable contributes to the PCs
|
||||||
|
- Variables pointing in similar directions are positively correlated
|
||||||
|
- Longer arrows indicate variables with higher variance
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def _render_hierarchical_clustering(data_dict: dict[str, pd.Series]):
|
||||||
|
"""Render hierarchical clustering dendrogram.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("Hierarchical Clustering of Variables")
|
||||||
|
|
||||||
|
n_features = len(data_dict)
|
||||||
|
|
||||||
|
# Reduce features if too many
|
||||||
|
reduced_data = data_dict
|
||||||
|
if n_features > 300:
|
||||||
|
st.warning(
|
||||||
|
f"⚠️ **Large dataset:** {n_features} features detected. "
|
||||||
|
"Automatically selecting the 300 most variable features for dendrogram visualization."
|
||||||
|
)
|
||||||
|
reduced_data = select_top_variable_features(data_dict, n_features=300, method="variance")
|
||||||
|
|
||||||
|
method = st.selectbox(
|
||||||
|
"Distance Metric (based on correlation)",
|
||||||
|
options=["pearson", "spearman"],
|
||||||
|
index=0,
|
||||||
|
format_func=lambda x: x.title(),
|
||||||
|
key="dendro_method",
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = create_dendrogram_plot(reduced_data, method=method)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Dendrogram** shows hierarchical relationships between variables based on correlation distance.
|
||||||
|
- Variables that merge at lower heights are more similar (highly correlated)
|
||||||
|
- Distinct clusters suggest groups of related features
|
||||||
|
- Useful for identifying redundant features or feature groups
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def _render_mutual_information(data_dict: dict[str, pd.Series]):
|
||||||
|
"""Render mutual information matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("Mutual Information Analysis")
|
||||||
|
|
||||||
|
n_features = len(data_dict)
|
||||||
|
|
||||||
|
# Mutual information is very computationally expensive - hard limit
|
||||||
|
if n_features > 200:
|
||||||
|
st.error(
|
||||||
|
f"❌ **Too many features:** Mutual information analysis with {n_features} features "
|
||||||
|
"would be too computationally expensive. Please reduce to ≤200 features using the "
|
||||||
|
"variable selection or aggregation filters."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
st.warning(
|
||||||
|
"⚠️ **Computationally intensive**: This analysis may take some time. "
|
||||||
|
"Dataset is subsampled to 10,000 cells for performance."
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = create_mutual_information_matrix(data_dict)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Mutual Information** captures both linear and non-linear relationships between variables.
|
||||||
|
- Unlike correlation, MI can detect complex dependencies
|
||||||
|
- Higher values indicate stronger information sharing
|
||||||
|
- Particularly useful for identifying non-linear relationships that correlation might miss
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def _render_variance_analysis(data_dict: dict[str, pd.Series]):
|
||||||
|
"""Render feature variance and spread analysis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict: Dictionary mapping variable names to pandas Series
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("Feature Variance and Spread")
|
||||||
|
|
||||||
|
fig = create_feature_variance_plot(data_dict)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Variance metrics** show the spread and variability of each feature:
|
||||||
|
- **Standard Deviation**: Average distance from the mean
|
||||||
|
- **IQR** (Interquartile Range): Spread of the middle 50% of data (robust to outliers)
|
||||||
|
- **Range**: Difference between max and min values
|
||||||
|
- **CV** (Coefficient of Variation): Normalized variability (std/mean)
|
||||||
|
|
||||||
|
Features with very low variance may provide little information for modeling.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_variable_groups(all_series: dict[str, pd.Series]) -> dict[str, list[str]]:
|
||||||
|
"""Group variables by data source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_series: Dictionary of all variable series
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping source names to lists of variable names
|
||||||
|
|
||||||
|
"""
|
||||||
|
var_groups = {}
|
||||||
|
for var_name in all_series.keys():
|
||||||
|
source = var_name.split("_")[0]
|
||||||
|
if source not in var_groups:
|
||||||
|
var_groups[source] = []
|
||||||
|
var_groups[source].append(var_name)
|
||||||
|
return var_groups
|
||||||
|
|
||||||
|
|
||||||
|
def _render_variable_selection(all_series: dict[str, pd.Series]) -> dict[str, pd.Series]:
|
||||||
|
"""Render variable selection UI and return filtered data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_series: Dictionary of all variable series
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered dictionary with only selected variables
|
||||||
|
|
||||||
|
"""
|
||||||
|
var_groups = _get_variable_groups(all_series)
|
||||||
|
selected_vars = []
|
||||||
|
|
||||||
|
# Create checkboxes for each group
|
||||||
|
for source, vars_in_group in sorted(var_groups.items()):
|
||||||
|
with st.container():
|
||||||
|
col1, col2 = st.columns([1, 4])
|
||||||
|
with col1:
|
||||||
|
select_all = st.checkbox(f"**{source}** ({len(vars_in_group)})", value=True, key=f"select_{source}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
if select_all:
|
||||||
|
selected_vars.extend(vars_in_group)
|
||||||
|
else:
|
||||||
|
# Show first few variables as examples
|
||||||
|
st.caption(", ".join(vars_in_group[:3]) + ("..." if len(vars_in_group) > 3 else ""))
|
||||||
|
|
||||||
|
# Filter data dict
|
||||||
|
filtered_data = {k: v for k, v in all_series.items() if k in selected_vars}
|
||||||
|
|
||||||
|
if len(filtered_data) < 2:
|
||||||
|
st.warning("⚠️ Please select at least 2 variables for correlation analysis")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
st.info(f"**Selected {len(filtered_data)} variables** for analysis")
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_correlations_tab(
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset],
|
||||||
|
grid_area_series: pd.Series | None = None,
|
||||||
|
):
|
||||||
|
"""Render the cross-dataset correlation analysis tab.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping member names to their xarray Datasets
|
||||||
|
grid_area_series: Optional pandas Series with cell_ids as index and area values
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This section provides comprehensive analysis of relationships and similarities
|
||||||
|
between all variables across different data sources (ArcticDEM, ERA5, AlphaEarth, etc.).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flatten all datasets into a single dict of pandas Series
|
||||||
|
with st.spinner("Preparing data for correlation analysis..."):
|
||||||
|
# First, compute all datasets if needed
|
||||||
|
computed_datasets = {}
|
||||||
|
for member_name, ds in member_datasets.items():
|
||||||
|
# Compute if lazy
|
||||||
|
if any(isinstance(v.data, type(ds)) for v in ds.data_vars.values()):
|
||||||
|
with st.spinner(f"Loading {member_name} data..."):
|
||||||
|
ds = ds.compute()
|
||||||
|
computed_datasets[member_name] = ds
|
||||||
|
|
||||||
|
# Render aggregation selection
|
||||||
|
with st.expander("🔧 Aggregation Selection", expanded=False):
|
||||||
|
dimension_filters = _render_aggregation_selection(computed_datasets)
|
||||||
|
|
||||||
|
if dimension_filters:
|
||||||
|
st.info(
|
||||||
|
f"**Filtering applied:** {sum(len(dims) for dims in dimension_filters.values())} "
|
||||||
|
f"dimension filter(s) across {len(dimension_filters)} data source(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now flatten with filters applied
|
||||||
|
with st.spinner("Flattening datasets..."):
|
||||||
|
all_series = {}
|
||||||
|
|
||||||
|
# Add grid area if provided
|
||||||
|
if grid_area_series is not None:
|
||||||
|
all_series["Grid_Cell_Area_km2"] = grid_area_series
|
||||||
|
|
||||||
|
# Process each member dataset with its dimension filter
|
||||||
|
for member_name, ds in computed_datasets.items():
|
||||||
|
prefix = f"{member_name}_"
|
||||||
|
member_filter = dimension_filters.get(member_name, None)
|
||||||
|
member_series = _flatten_dataset_to_series(ds, prefix=prefix, dimension_filter=member_filter)
|
||||||
|
all_series.update(member_series)
|
||||||
|
|
||||||
|
st.success(f"✅ Loaded {len(all_series)} variables from {len(member_datasets)} data sources")
|
||||||
|
|
||||||
|
# Variable selection
|
||||||
|
with st.expander("🔧 Variable Selection", expanded=False):
|
||||||
|
st.markdown("Select which variables to include in the correlation analysis.")
|
||||||
|
filtered_data = _render_variable_selection(all_series)
|
||||||
|
|
||||||
|
# Create tabs for different analyses
|
||||||
|
analysis_tabs = st.tabs(
|
||||||
|
[
|
||||||
|
"📊 Correlation Matrix",
|
||||||
|
"🔬 PCA Biplot",
|
||||||
|
"🌳 Hierarchical Clustering",
|
||||||
|
"📈 Variance Analysis",
|
||||||
|
"🔗 Mutual Information",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with analysis_tabs[0]:
|
||||||
|
_render_correlation_matrix(filtered_data)
|
||||||
|
|
||||||
|
with analysis_tabs[1]:
|
||||||
|
_render_pca_analysis(filtered_data)
|
||||||
|
|
||||||
|
with analysis_tabs[2]:
|
||||||
|
_render_hierarchical_clustering(filtered_data)
|
||||||
|
|
||||||
|
with analysis_tabs[3]:
|
||||||
|
_render_variance_analysis(filtered_data)
|
||||||
|
|
||||||
|
with analysis_tabs[4]:
|
||||||
|
_render_mutual_information(filtered_data)
|
||||||
|
|
@ -0,0 +1,331 @@
|
||||||
|
"""Feature importance analysis section for experiment comparison."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.experiment_comparison import (
|
||||||
|
create_data_source_importance_bars,
|
||||||
|
create_feature_consistency_plot,
|
||||||
|
create_feature_importance_by_grid_level,
|
||||||
|
)
|
||||||
|
from entropice.dashboard.utils.loaders import (
|
||||||
|
AutogluonTrainingResult,
|
||||||
|
TrainingResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_feature_importance_from_results(
|
||||||
|
training_results: list[TrainingResult],
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Extract feature importance from all training results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_results: List of TrainingResult objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns: feature, importance, model, grid, level, task, target
|
||||||
|
|
||||||
|
"""
|
||||||
|
records = []
|
||||||
|
|
||||||
|
for tr in training_results:
|
||||||
|
# Load model state if available
|
||||||
|
model_state = tr.load_model_state()
|
||||||
|
if model_state is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = tr.display_info
|
||||||
|
|
||||||
|
# Extract feature importance based on available data
|
||||||
|
if "feature_importance" in model_state.data_vars:
|
||||||
|
# eSPA or similar models with direct feature importance
|
||||||
|
importance_data = model_state["feature_importance"]
|
||||||
|
for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
|
||||||
|
importance_value = float(importance_data.isel(feature=feature_idx).values)
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
"feature": str(feature_name),
|
||||||
|
"importance": importance_value,
|
||||||
|
"model": info.model,
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif "gain" in model_state.data_vars:
|
||||||
|
# XGBoost-style feature importance
|
||||||
|
gain_data = model_state["gain"]
|
||||||
|
for feature_idx, feature_name in enumerate(gain_data.coords["feature"].values):
|
||||||
|
importance_value = float(gain_data.isel(feature=feature_idx).values)
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
"feature": str(feature_name),
|
||||||
|
"importance": importance_value,
|
||||||
|
"model": info.model,
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif "feature_importances_" in model_state.data_vars:
|
||||||
|
# Random Forest style
|
||||||
|
importance_data = model_state["feature_importances_"]
|
||||||
|
for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
|
||||||
|
importance_value = float(importance_data.isel(feature=feature_idx).values)
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
"feature": str(feature_name),
|
||||||
|
"importance": importance_value,
|
||||||
|
"model": info.model,
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pd.DataFrame(records)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_feature_importance_from_autogluon(
|
||||||
|
autogluon_results: list[AutogluonTrainingResult],
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Extract feature importance from AutoGluon results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
autogluon_results: List of AutogluonTrainingResult objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns: feature, importance, model, grid, level, task, target
|
||||||
|
|
||||||
|
"""
|
||||||
|
records = []
|
||||||
|
|
||||||
|
for ag in autogluon_results:
|
||||||
|
if ag.feature_importance is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = ag.display_info
|
||||||
|
|
||||||
|
# AutoGluon feature importance is already a DataFrame with features as index
|
||||||
|
for feature_name, importance_value in ag.feature_importance["importance"].items():
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
"feature": str(feature_name),
|
||||||
|
"importance": float(importance_value),
|
||||||
|
"model": "autogluon",
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pd.DataFrame(records)
|
||||||
|
|
||||||
|
|
||||||
|
def _categorize_feature(feature_name: str) -> str:
|
||||||
|
"""Categorize feature by data source."""
|
||||||
|
feature_lower = feature_name.lower()
|
||||||
|
if feature_lower.startswith("arcticdem"):
|
||||||
|
return "ArcticDEM"
|
||||||
|
if feature_lower.startswith("era5"):
|
||||||
|
return "ERA5"
|
||||||
|
if feature_lower.startswith("embeddings") or feature_lower.startswith("alphaearth"):
|
||||||
|
return "Embeddings"
|
||||||
|
return "General"
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_feature_importance_data(
|
||||||
|
training_results: list[TrainingResult],
|
||||||
|
autogluon_results: list[AutogluonTrainingResult],
|
||||||
|
) -> pd.DataFrame | None:
|
||||||
|
"""Extract and prepare feature importance data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_results: List of RandomSearchCV training results
|
||||||
|
autogluon_results: List of AutoGluon training results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with feature importance data or None if no data available
|
||||||
|
|
||||||
|
"""
|
||||||
|
fi_df_cv = _extract_feature_importance_from_results(training_results)
|
||||||
|
fi_df_ag = _extract_feature_importance_from_autogluon(autogluon_results)
|
||||||
|
|
||||||
|
if fi_df_cv.empty and fi_df_ag.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Combine both
|
||||||
|
fi_df = pd.concat([fi_df_cv, fi_df_ag], ignore_index=True)
|
||||||
|
|
||||||
|
# Add data source categorization
|
||||||
|
fi_df["data_source"] = fi_df["feature"].apply(_categorize_feature)
|
||||||
|
fi_df["grid_level"] = fi_df["grid"] + "_" + fi_df["level"].astype(str)
|
||||||
|
|
||||||
|
return fi_df
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_feature_importance_analysis(
|
||||||
|
training_results: list[TrainingResult],
|
||||||
|
autogluon_results: list[AutogluonTrainingResult],
|
||||||
|
):
|
||||||
|
"""Render feature importance analysis section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_results: List of RandomSearchCV training results
|
||||||
|
autogluon_results: List of AutoGluon training results
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.header("🔍 Feature Importance Analysis")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This section analyzes which features are most important across different
|
||||||
|
models, grid levels, tasks, and targets.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract feature importance
|
||||||
|
with st.spinner("Extracting feature importance from training results..."):
|
||||||
|
fi_df = _prepare_feature_importance_data(training_results, autogluon_results)
|
||||||
|
|
||||||
|
if fi_df is None:
|
||||||
|
st.warning("No feature importance data available. Model state files may be missing.")
|
||||||
|
return
|
||||||
|
|
||||||
|
st.success(f"Extracted feature importance from {len(fi_df)} feature-model combinations")
|
||||||
|
|
||||||
|
# Filters
|
||||||
|
st.subheader("Filters")
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
# Task filter
|
||||||
|
available_tasks = ["All", *sorted(fi_df["task"].unique().tolist())]
|
||||||
|
selected_task = st.selectbox("Task", options=available_tasks, index=0, key="fi_task_filter")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
# Target filter
|
||||||
|
available_targets = ["All", *sorted(fi_df["target"].unique().tolist())]
|
||||||
|
selected_target = st.selectbox("Target Dataset", options=available_targets, index=0, key="fi_target_filter")
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
# Top N features
|
||||||
|
top_n_features = st.number_input("Top N Features", min_value=5, max_value=50, value=15, key="top_n_features")
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
filtered_fi_df = fi_df.copy()
|
||||||
|
if selected_task != "All":
|
||||||
|
filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["task"] == selected_task]
|
||||||
|
if selected_target != "All":
|
||||||
|
filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["target"] == selected_target]
|
||||||
|
|
||||||
|
if len(filtered_fi_df) == 0:
|
||||||
|
st.warning("No feature importance data available for the selected filters.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Section 1: Top features by grid level
|
||||||
|
st.subheader("Top Features by Grid Level")
|
||||||
|
|
||||||
|
try:
|
||||||
|
fig = create_feature_importance_by_grid_level(filtered_fi_df, top_n=top_n_features)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Could not create feature importance by grid level plot: {e}")
|
||||||
|
|
||||||
|
# Show detailed breakdown in expander
|
||||||
|
grid_levels = sorted(filtered_fi_df["grid_level"].unique())
|
||||||
|
|
||||||
|
with st.expander("Show Detailed Breakdown by Grid Level", expanded=False):
|
||||||
|
for grid_level in grid_levels:
|
||||||
|
grid_data = filtered_fi_df[filtered_fi_df["grid_level"] == grid_level]
|
||||||
|
|
||||||
|
# Get top features for this grid level
|
||||||
|
top_features_grid = (
|
||||||
|
grid_data.groupby("feature")["importance"].mean().reset_index().nlargest(top_n_features, "importance")
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown(f"**{grid_level.replace('_', '-').title()}**")
|
||||||
|
|
||||||
|
# Create display dataframe with data source
|
||||||
|
display_df = top_features_grid.merge(
|
||||||
|
grid_data[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
|
||||||
|
)
|
||||||
|
display_df.columns = ["Feature", "Mean Importance", "Data Source"]
|
||||||
|
display_df = display_df.sort_values("Mean Importance", ascending=False)
|
||||||
|
|
||||||
|
st.dataframe(display_df, width="stretch", hide_index=True)
|
||||||
|
|
||||||
|
# Section 2: Feature importance consistency across models
|
||||||
|
st.subheader("Feature Importance Consistency Across Models")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Coefficient of Variation (CV)**: Lower values indicate more consistent importance across models.
|
||||||
|
High CV suggests the feature's importance varies significantly between different models.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fig = create_feature_consistency_plot(filtered_fi_df, top_n=top_n_features)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Could not create feature consistency plot: {e}")
|
||||||
|
|
||||||
|
# Show detailed statistics in expander
|
||||||
|
with st.expander("Show Detailed Statistics", expanded=False):
|
||||||
|
# Get top features overall
|
||||||
|
overall_top_features = (
|
||||||
|
filtered_fi_df.groupby("feature")["importance"]
|
||||||
|
.mean()
|
||||||
|
.reset_index()
|
||||||
|
.nlargest(top_n_features, "importance")["feature"]
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate variance in importance across models for each feature
|
||||||
|
feature_variance = (
|
||||||
|
filtered_fi_df[filtered_fi_df["feature"].isin(overall_top_features)]
|
||||||
|
.groupby("feature")["importance"]
|
||||||
|
.agg(["mean", "std", "min", "max"])
|
||||||
|
.reset_index()
|
||||||
|
)
|
||||||
|
feature_variance["coefficient_of_variation"] = feature_variance["std"] / feature_variance["mean"]
|
||||||
|
feature_variance = feature_variance.sort_values("mean", ascending=False)
|
||||||
|
|
||||||
|
# Add data source
|
||||||
|
feature_variance = feature_variance.merge(
|
||||||
|
filtered_fi_df[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_variance.columns = ["Feature", "Mean", "Std Dev", "Min", "Max", "CV", "Data Source"]
|
||||||
|
|
||||||
|
st.dataframe(
|
||||||
|
feature_variance[["Feature", "Data Source", "Mean", "Std Dev", "CV"]],
|
||||||
|
width="stretch",
|
||||||
|
hide_index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Section 3: Feature importance by data source
|
||||||
|
st.subheader("Feature Importance by Data Source")
|
||||||
|
|
||||||
|
try:
|
||||||
|
fig = create_data_source_importance_bars(filtered_fi_df)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Could not create data source importance chart: {e}")
|
||||||
|
|
||||||
|
# Show detailed table in expander
|
||||||
|
with st.expander("Show Data Source Statistics", expanded=False):
|
||||||
|
# Aggregate importance by data source
|
||||||
|
source_importance = (
|
||||||
|
filtered_fi_df.groupby("data_source")["importance"].agg(["sum", "mean", "count"]).reset_index()
|
||||||
|
)
|
||||||
|
source_importance.columns = ["Data Source", "Total Importance", "Mean Importance", "Feature Count"]
|
||||||
|
source_importance = source_importance.sort_values("Total Importance", ascending=False)
|
||||||
|
|
||||||
|
st.dataframe(source_importance, width="stretch", hide_index=True)
|
||||||
75
src/entropice/dashboard/sections/experiment_grid_analysis.py
Normal file
75
src/entropice/dashboard/sections/experiment_grid_analysis.py
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""Grid-level analysis section for experiment comparison."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.experiment_comparison import create_grid_level_comparison_plot
|
||||||
|
from entropice.dashboard.utils.formatters import format_metric_name
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_grid_level_analysis(summary_df: pd.DataFrame, available_metrics: list[str]):
|
||||||
|
"""Render grid-level analysis section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summary_df: Summary DataFrame with all results
|
||||||
|
available_metrics: List of available metrics
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.header("📐 Grid Level Analysis")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This section analyzes how different grid levels affect model performance.
|
||||||
|
Compare performance across grid types (hex vs healpix) and resolution levels.
|
||||||
|
Metrics are automatically selected based on the task type.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine metrics to use per task
|
||||||
|
task_metric_map = {
|
||||||
|
"binary": "f1",
|
||||||
|
"count_regimes": "f1_weighted",
|
||||||
|
"density_regimes": "f1_weighted",
|
||||||
|
"count": "r2",
|
||||||
|
"density": "r2",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get unique tasks in the data
|
||||||
|
unique_tasks = summary_df["task"].unique()
|
||||||
|
|
||||||
|
# Split selection
|
||||||
|
split = st.selectbox("Data Split", options=["test", "train", "combined"], index=0, key="grid_split")
|
||||||
|
|
||||||
|
# Create plots for each task
|
||||||
|
for task in sorted(unique_tasks):
|
||||||
|
# Determine the metric for this task
|
||||||
|
metric = task_metric_map.get(task, available_metrics[0])
|
||||||
|
|
||||||
|
# Check if metric is available
|
||||||
|
metric_col = f"{split}_{metric}"
|
||||||
|
if metric_col not in summary_df.columns:
|
||||||
|
# Fall back to first available metric
|
||||||
|
metric = available_metrics[0]
|
||||||
|
metric_col = f"{split}_{metric}"
|
||||||
|
|
||||||
|
st.subheader(f"{task.replace('_', ' ').title()} - {format_metric_name(metric)}")
|
||||||
|
|
||||||
|
# Filter data for this task
|
||||||
|
task_df = summary_df[summary_df["task"] == task]
|
||||||
|
|
||||||
|
# Create grid-level comparison plot
|
||||||
|
try:
|
||||||
|
fig = create_grid_level_comparison_plot(task_df, metric, split)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Could not create grid-level comparison plot for {task}: {e}")
|
||||||
|
|
||||||
|
# Show statistics by grid level for this task
|
||||||
|
if metric_col in task_df.columns:
|
||||||
|
stats = (
|
||||||
|
task_df.groupby(["grid", "level"])[metric_col].agg(["mean", "std", "min", "max", "count"]).reset_index()
|
||||||
|
)
|
||||||
|
stats.columns = ["Grid", "Level", "Mean", "Std Dev", "Min", "Max", "Count"]
|
||||||
|
with st.expander(f"Show {format_metric_name(metric)} Statistics by Grid Level", expanded=False):
|
||||||
|
st.dataframe(stats.sort_values("Mean", ascending=False), width="stretch", hide_index=True)
|
||||||
193
src/entropice/dashboard/sections/experiment_inference_maps.py
Normal file
193
src/entropice/dashboard/sections/experiment_inference_maps.py
Normal file
|
|
@ -0,0 +1,193 @@
|
||||||
|
"""Section for visualizing experiment inference maps."""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.experiment_comparison import create_inference_maps
|
||||||
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
|
from entropice.utils.types import GridConfig
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_inference_maps_section(
|
||||||
|
experiment_name: str,
|
||||||
|
training_results: list[TrainingResult],
|
||||||
|
) -> None:
|
||||||
|
"""Render the inference maps section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the experiment
|
||||||
|
training_results: List of training results for the experiment
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.header("🗺️ Inference Maps Analysis")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Visualize the mean and uncertainty (standard deviation) of model predictions across the Arctic region.
|
||||||
|
The maps show the spatial distribution of predictions aggregated across multiple training runs.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
if not training_results:
|
||||||
|
st.info("No training results available for inference maps.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract unique grid configurations from training results
|
||||||
|
available_grid_configs = sorted(
|
||||||
|
{GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) for tr in training_results},
|
||||||
|
key=lambda gc: gc.sort_key,
|
||||||
|
)
|
||||||
|
available_tasks = sorted({tr.settings.task for tr in training_results})
|
||||||
|
available_targets = sorted({tr.settings.target for tr in training_results})
|
||||||
|
available_models = sorted({tr.settings.model for tr in training_results})
|
||||||
|
|
||||||
|
# Create form for selecting parameters
|
||||||
|
with st.form("inference_map_form"):
|
||||||
|
st.subheader("Map Configuration")
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
selected_grid_config = st.selectbox(
|
||||||
|
"Grid Configuration",
|
||||||
|
options=available_grid_configs,
|
||||||
|
format_func=lambda gc: gc.display_name,
|
||||||
|
help="Select the grid type and resolution level for the inference map",
|
||||||
|
)
|
||||||
|
|
||||||
|
selected_task = st.selectbox(
|
||||||
|
"Task",
|
||||||
|
options=available_tasks,
|
||||||
|
help="Select the prediction task",
|
||||||
|
)
|
||||||
|
|
||||||
|
st.subheader("Filters")
|
||||||
|
|
||||||
|
col3, col4 = st.columns(2)
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
selected_targets = st.multiselect(
|
||||||
|
"Target Datasets",
|
||||||
|
options=available_targets,
|
||||||
|
default=available_targets,
|
||||||
|
help="Filter by target datasets (select all to include all)",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
selected_models = st.multiselect(
|
||||||
|
"Model Types",
|
||||||
|
options=available_models,
|
||||||
|
default=available_models,
|
||||||
|
help="Filter by model types (select all to include all)",
|
||||||
|
)
|
||||||
|
|
||||||
|
submit_button = st.form_submit_button("Generate Maps", type="primary")
|
||||||
|
|
||||||
|
if submit_button:
|
||||||
|
# Extract grid and level from selected config
|
||||||
|
selected_grid = selected_grid_config.grid
|
||||||
|
selected_level = selected_grid_config.level
|
||||||
|
|
||||||
|
# Filter training results based on selections
|
||||||
|
filtered_results = [
|
||||||
|
tr
|
||||||
|
for tr in training_results
|
||||||
|
if tr.settings.grid == selected_grid
|
||||||
|
and tr.settings.level == selected_level
|
||||||
|
and tr.settings.task == selected_task
|
||||||
|
and tr.settings.target in selected_targets
|
||||||
|
and tr.settings.model in selected_models
|
||||||
|
]
|
||||||
|
|
||||||
|
if not filtered_results:
|
||||||
|
st.warning("No training results match the selected criteria. Please adjust your selections and try again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
st.success(f"Found {len(filtered_results)} training runs matching the criteria.")
|
||||||
|
|
||||||
|
# Display metadata about selected runs
|
||||||
|
with st.expander("View Selected Training Runs"):
|
||||||
|
run_info = []
|
||||||
|
for tr in filtered_results:
|
||||||
|
info = tr.display_info
|
||||||
|
run_info.append(
|
||||||
|
{
|
||||||
|
"Task": info.task,
|
||||||
|
"Target": info.target,
|
||||||
|
"Model": info.model,
|
||||||
|
"Grid": f"{info.grid}_{info.level}",
|
||||||
|
"Created": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
st.dataframe(run_info, width="stretch")
|
||||||
|
|
||||||
|
# Calculate inference maps
|
||||||
|
with st.spinner("Calculating inference maps from predictions..."):
|
||||||
|
try:
|
||||||
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
|
|
||||||
|
inference_gdf = TrainingResult.calculate_inference_maps(filtered_results)
|
||||||
|
|
||||||
|
st.info(
|
||||||
|
f"Generated inference map with {len(inference_gdf):,} grid cells. "
|
||||||
|
+ (
|
||||||
|
"Using point-based rendering (>50k cells)."
|
||||||
|
if len(inference_gdf) > 50000
|
||||||
|
else "Using polygon-based rendering."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssertionError as e:
|
||||||
|
st.error(f"Error calculating inference maps: {e}")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Unexpected error calculating inference maps: {e}")
|
||||||
|
st.exception(e)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Generate maps
|
||||||
|
with st.spinner("Generating cartographic visualizations..."):
|
||||||
|
try:
|
||||||
|
deck_mean, deck_std = create_inference_maps(
|
||||||
|
inference_gdf,
|
||||||
|
grid=selected_grid,
|
||||||
|
level=selected_level,
|
||||||
|
task=selected_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display maps
|
||||||
|
st.subheader("📍 Mean Prediction Map")
|
||||||
|
st.pydeck_chart(deck_mean, use_container_width=True)
|
||||||
|
|
||||||
|
st.subheader("📍 Uncertainty Map (Standard Deviation)")
|
||||||
|
st.pydeck_chart(deck_std, use_container_width=True)
|
||||||
|
|
||||||
|
# Display statistics
|
||||||
|
st.subheader("📊 Inference Statistics")
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.metric("Grid Cells", f"{len(inference_gdf):,}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.metric(
|
||||||
|
"Mean Prediction",
|
||||||
|
f"{inference_gdf['mean_prediction'].mean():.4f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
st.metric(
|
||||||
|
"Avg Uncertainty",
|
||||||
|
f"{inference_gdf['std_prediction'].mean():.4f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
st.metric(
|
||||||
|
"Max Uncertainty",
|
||||||
|
f"{inference_gdf['std_prediction'].max():.4f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Error generating maps: {e}")
|
||||||
|
st.exception(e)
|
||||||
|
return
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""Model comparison section for experiment analysis."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.experiment_comparison import create_top_models_bar_chart
|
||||||
|
from entropice.dashboard.utils.formatters import format_metric_name
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_model_comparison(summary_df: pd.DataFrame, available_metrics: list[str]):
|
||||||
|
"""Render model comparison section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summary_df: Summary DataFrame with all results
|
||||||
|
available_metrics: List of available metrics
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.header("🏆 Model Performance Comparison")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This section shows the best performing models for each task.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine metrics to use per task
|
||||||
|
task_metric_map = {
|
||||||
|
"binary": "f1",
|
||||||
|
"count_regimes": "f1_weighted",
|
||||||
|
"density_regimes": "f1_weighted",
|
||||||
|
"count": "r2",
|
||||||
|
"density": "r2",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Split selection
|
||||||
|
split = st.selectbox("Data Split", options=["test", "train", "combined"], index=0, key="model_split")
|
||||||
|
|
||||||
|
# Get unique tasks
|
||||||
|
unique_tasks = summary_df["task"].unique()
|
||||||
|
|
||||||
|
# For each task, show the best models
|
||||||
|
for task in sorted(unique_tasks):
|
||||||
|
# Determine the metric for this task
|
||||||
|
metric = task_metric_map.get(task, available_metrics[0])
|
||||||
|
metric_col = f"{split}_{metric}"
|
||||||
|
|
||||||
|
if metric_col not in summary_df.columns:
|
||||||
|
# Fall back to first available metric
|
||||||
|
metric = available_metrics[0]
|
||||||
|
metric_col = f"{split}_{metric}"
|
||||||
|
|
||||||
|
st.subheader(f"{task.replace('_', ' ').title()}")
|
||||||
|
|
||||||
|
# Filter data for this task
|
||||||
|
task_df = summary_df[summary_df["task"] == task].copy()
|
||||||
|
|
||||||
|
# Create visualization
|
||||||
|
try:
|
||||||
|
fig = create_top_models_bar_chart(task_df, metric, task, split, top_n=10)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Could not create model comparison chart for {task}: {e}")
|
||||||
|
|
||||||
|
# Show table in expander
|
||||||
|
with st.expander("Show Top 10 Models Table", expanded=False):
|
||||||
|
top_models = task_df.nlargest(10, metric_col)[
|
||||||
|
["model", "grid", "level", "target", metric_col, "method"]
|
||||||
|
].copy()
|
||||||
|
top_models.columns = ["Model", "Grid", "Level", "Target", format_metric_name(metric), "Method"]
|
||||||
|
st.dataframe(
|
||||||
|
top_models.reset_index(drop=True),
|
||||||
|
width="stretch",
|
||||||
|
hide_index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
92
src/entropice/dashboard/sections/experiment_overview.py
Normal file
92
src/entropice/dashboard/sections/experiment_overview.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""Experiment overview and sidebar sections."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.loaders import (
|
||||||
|
AutogluonTrainingResult,
|
||||||
|
TrainingResult,
|
||||||
|
get_available_experiments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_experiment_sidebar() -> str | None:
|
||||||
|
"""Render sidebar for experiment selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected experiment name or None
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.sidebar.header("🔬 Experiment Selection")
|
||||||
|
|
||||||
|
experiments = get_available_experiments()
|
||||||
|
|
||||||
|
if not experiments:
|
||||||
|
st.sidebar.warning("No experiments found. Create an experiment directory with training results first.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
selected_experiment = st.sidebar.selectbox(
|
||||||
|
"Select Experiment",
|
||||||
|
options=experiments,
|
||||||
|
index=0,
|
||||||
|
help="Choose an experiment to analyze",
|
||||||
|
)
|
||||||
|
|
||||||
|
return selected_experiment
|
||||||
|
|
||||||
|
|
||||||
|
def render_experiment_overview(
|
||||||
|
experiment_name: str,
|
||||||
|
training_results: list[TrainingResult],
|
||||||
|
autogluon_results: list[AutogluonTrainingResult],
|
||||||
|
summary_df: pd.DataFrame,
|
||||||
|
):
|
||||||
|
"""Render experiment overview section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the experiment
|
||||||
|
training_results: List of RandomSearchCV training results
|
||||||
|
autogluon_results: List of AutoGluon training results
|
||||||
|
summary_df: Summary DataFrame with all results
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.header(f"📊 Experiment: {experiment_name}")
|
||||||
|
|
||||||
|
# Show summary statistics
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Training Runs", len(training_results) + len(autogluon_results))
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.metric("RandomSearchCV Runs", len(training_results))
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
st.metric("AutoGluon Runs", len(autogluon_results))
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
unique_configs = summary_df[["grid", "level", "task", "target"]].drop_duplicates()
|
||||||
|
st.metric("Unique Configurations", len(unique_configs))
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Show summary table
|
||||||
|
st.subheader("Experiment Summary")
|
||||||
|
|
||||||
|
display_columns = [
|
||||||
|
"method",
|
||||||
|
"task",
|
||||||
|
"target",
|
||||||
|
"model",
|
||||||
|
"grid_level",
|
||||||
|
"test_score",
|
||||||
|
"best_metric",
|
||||||
|
"n_trials",
|
||||||
|
]
|
||||||
|
available_columns = [col for col in display_columns if col in summary_df.columns]
|
||||||
|
|
||||||
|
st.dataframe(
|
||||||
|
summary_df[available_columns].sort_values("test_score", ascending=False),
|
||||||
|
width="stretch",
|
||||||
|
hide_index=True,
|
||||||
|
)
|
||||||
0
src/entropice/dashboard/utils/__init__.py
Normal file
0
src/entropice/dashboard/utils/__init__.py
Normal file
|
|
@ -8,6 +8,7 @@ from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import antimeridian
|
import antimeridian
|
||||||
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import toml
|
import toml
|
||||||
|
|
@ -19,7 +20,7 @@ import entropice.utils.paths
|
||||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||||
from entropice.ml.autogluon_training import AutoGluonTrainingSettings
|
from entropice.ml.autogluon_training import AutoGluonTrainingSettings
|
||||||
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||||
from entropice.ml.training import TrainingSettings
|
from entropice.ml.randomsearch import TrainingSettings
|
||||||
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
|
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -215,6 +216,37 @@ class TrainingResult:
|
||||||
records.append(record)
|
records.append(record)
|
||||||
return pd.DataFrame.from_records(records)
|
return pd.DataFrame.from_records(records)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_inference_maps(training_results: list["TrainingResult"]) -> gpd.GeoDataFrame:
|
||||||
|
"""Calculate the mean and standard deviation of inference maps across multiple training results."""
|
||||||
|
assert len({tr.settings.grid for tr in training_results}) == 1, "All training results must have the same grid"
|
||||||
|
assert len({tr.settings.level for tr in training_results}) == 1, "All training results must have the same level"
|
||||||
|
|
||||||
|
grid = training_results[0].settings.grid
|
||||||
|
level = training_results[0].settings.level
|
||||||
|
gridfile = entropice.utils.paths.get_grid_file(grid, level)
|
||||||
|
cells = gpd.read_parquet(gridfile, columns=["cell_id", "geometry"])
|
||||||
|
if grid == "hex":
|
||||||
|
cells["cell_id"] = cells["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
cells = cells.set_index("cell_id")
|
||||||
|
|
||||||
|
vals = []
|
||||||
|
for tr in training_results:
|
||||||
|
preds_file = tr.path / "predicted_probabilities.parquet"
|
||||||
|
if not preds_file.exists():
|
||||||
|
continue
|
||||||
|
preds = pd.read_parquet(preds_file, columns=["cell_id", "predicted"]).set_index("cell_id")
|
||||||
|
if preds["predicted"].dtype == "category":
|
||||||
|
preds["predicted"] = preds["predicted"].cat.codes
|
||||||
|
vals.append(preds)
|
||||||
|
all_preds = pd.concat(vals, axis=1)
|
||||||
|
mean_preds = all_preds.mean(axis=1)
|
||||||
|
std_preds = all_preds.std(axis=1)
|
||||||
|
cells["mean_prediction"] = mean_preds
|
||||||
|
cells["std_prediction"] = std_preds
|
||||||
|
|
||||||
|
return cells.reset_index()
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data(ttl=300) # Cache for 5 minutes
|
@st.cache_data(ttl=300) # Cache for 5 minutes
|
||||||
def load_all_training_results() -> list[TrainingResult]:
|
def load_all_training_results() -> list[TrainingResult]:
|
||||||
|
|
@ -410,6 +442,185 @@ def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Ta
|
||||||
return train_data_dict
|
return train_data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_experiments() -> list[str]:
|
||||||
|
"""Get list of available experiment names from the results directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of experiment directory names
|
||||||
|
|
||||||
|
"""
|
||||||
|
results_dir = entropice.utils.paths.RESULTS_DIR
|
||||||
|
experiments = []
|
||||||
|
|
||||||
|
for item in results_dir.iterdir():
|
||||||
|
if not item.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this directory contains training results (has subdirs with training results files)
|
||||||
|
has_training_results = False
|
||||||
|
for subitem in item.iterdir():
|
||||||
|
if not subitem.is_dir():
|
||||||
|
continue
|
||||||
|
# Check for either RandomSearchCV or AutoGluon result files
|
||||||
|
if (subitem / "search_results.parquet").exists() or (subitem / "leaderboard.parquet").exists():
|
||||||
|
has_training_results = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if has_training_results:
|
||||||
|
experiments.append(item.name)
|
||||||
|
|
||||||
|
return sorted(experiments)
|
||||||
|
|
||||||
|
|
||||||
|
def load_experiment_training_results(experiment_name: str) -> list[TrainingResult]:
|
||||||
|
"""Load all training results for a specific experiment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the experiment directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TrainingResult objects for the experiment
|
||||||
|
|
||||||
|
"""
|
||||||
|
experiment_dir = entropice.utils.paths.RESULTS_DIR / experiment_name
|
||||||
|
if not experiment_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
training_results: list[TrainingResult] = []
|
||||||
|
for result_path in experiment_dir.iterdir():
|
||||||
|
if not result_path.is_dir():
|
||||||
|
continue
|
||||||
|
# Skip AutoGluon results
|
||||||
|
if "autogluon" in result_path.name.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_result = TrainingResult.from_path(result_path, experiment_name)
|
||||||
|
training_results.append(training_result)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # Skip incomplete results
|
||||||
|
|
||||||
|
# Sort by creation time (most recent first)
|
||||||
|
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||||
|
return training_results
|
||||||
|
|
||||||
|
|
||||||
|
def load_experiment_autogluon_results(experiment_name: str) -> list[AutogluonTrainingResult]:
|
||||||
|
"""Load all AutoGluon training results for a specific experiment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the experiment directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AutogluonTrainingResult objects for the experiment
|
||||||
|
|
||||||
|
"""
|
||||||
|
experiment_dir = entropice.utils.paths.RESULTS_DIR / experiment_name
|
||||||
|
if not experiment_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
training_results: list[AutogluonTrainingResult] = []
|
||||||
|
for result_path in experiment_dir.iterdir():
|
||||||
|
if not result_path.is_dir():
|
||||||
|
continue
|
||||||
|
# Only include AutoGluon results
|
||||||
|
if "autogluon" not in result_path.name.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_result = AutogluonTrainingResult.from_path(result_path, experiment_name)
|
||||||
|
training_results.append(training_result)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # Skip incomplete results
|
||||||
|
|
||||||
|
# Sort by creation time (most recent first)
|
||||||
|
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||||
|
return training_results
|
||||||
|
|
||||||
|
|
||||||
|
def create_experiment_summary_df(
|
||||||
|
training_results: list[TrainingResult], autogluon_results: list[AutogluonTrainingResult]
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Create a summary DataFrame for all results in an experiment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_results: List of TrainingResult objects
|
||||||
|
autogluon_results: List of AutogluonTrainingResult objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with summary statistics for the experiment
|
||||||
|
|
||||||
|
"""
|
||||||
|
records = []
|
||||||
|
|
||||||
|
# Add RandomSearchCV results
|
||||||
|
for tr in training_results:
|
||||||
|
info = tr.display_info
|
||||||
|
best_metric_name = tr._get_best_metric_name()
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"method": "RandomSearchCV",
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
"model": info.model,
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"grid_level": f"{info.grid}_{info.level}",
|
||||||
|
"train_score": tr.train_metrics.get(best_metric_name, float("nan")),
|
||||||
|
"test_score": tr.test_metrics.get(best_metric_name, float("nan")),
|
||||||
|
"combined_score": tr.combined_metrics.get(best_metric_name, float("nan")),
|
||||||
|
"best_metric": best_metric_name,
|
||||||
|
"n_trials": len(tr.results),
|
||||||
|
"created_at": tr.created_at,
|
||||||
|
"path": tr.path,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add all train metrics
|
||||||
|
for metric, value in tr.train_metrics.items():
|
||||||
|
record[f"train_{metric}"] = value
|
||||||
|
|
||||||
|
# Add all test metrics
|
||||||
|
for metric, value in tr.test_metrics.items():
|
||||||
|
record[f"test_{metric}"] = value
|
||||||
|
|
||||||
|
# Add all combined metrics
|
||||||
|
for metric, value in tr.combined_metrics.items():
|
||||||
|
record[f"combined_{metric}"] = value
|
||||||
|
|
||||||
|
records.append(record)
|
||||||
|
|
||||||
|
# Add AutoGluon results
|
||||||
|
for ag in autogluon_results:
|
||||||
|
info = ag.display_info
|
||||||
|
best_metric_name = ag._get_best_metric_name()
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"method": "AutoGluon",
|
||||||
|
"task": info.task,
|
||||||
|
"target": info.target,
|
||||||
|
"model": "ensemble", # AutoGluon is an ensemble
|
||||||
|
"grid": info.grid,
|
||||||
|
"level": info.level,
|
||||||
|
"grid_level": f"{info.grid}_{info.level}",
|
||||||
|
"train_score": float("nan"), # AutoGluon doesn't separate train scores
|
||||||
|
"test_score": ag.test_metrics.get(best_metric_name, float("nan")),
|
||||||
|
"combined_score": float("nan"),
|
||||||
|
"best_metric": best_metric_name,
|
||||||
|
"n_trials": len(ag.leaderboard),
|
||||||
|
"created_at": ag.created_at,
|
||||||
|
"path": ag.path,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add test metrics
|
||||||
|
for metric, value in ag.test_metrics.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
record[f"test_{metric}"] = value
|
||||||
|
|
||||||
|
records.append(record)
|
||||||
|
|
||||||
|
return pd.DataFrame(records)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageInfo:
|
class StorageInfo:
|
||||||
"""Storage information for a directory."""
|
"""Storage information for a directory."""
|
||||||
|
|
|
||||||
0
src/entropice/dashboard/views/__init__.py
Normal file
0
src/entropice/dashboard/views/__init__.py
Normal file
|
|
@ -138,6 +138,7 @@ def render_dataset_page():
|
||||||
)
|
)
|
||||||
if era5_members:
|
if era5_members:
|
||||||
tab_names.append("🌡️ ERA5")
|
tab_names.append("🌡️ ERA5")
|
||||||
|
tab_names += ["🔗 Correlations"]
|
||||||
tabs = st.tabs(tab_names)
|
tabs = st.tabs(tab_names)
|
||||||
|
|
||||||
with tabs[0]:
|
with tabs[0]:
|
||||||
|
|
@ -168,6 +169,12 @@ def render_dataset_page():
|
||||||
era5_member_dataset = {m: member_datasets[m] for m in era5_members}
|
era5_member_dataset = {m: member_datasets[m] for m in era5_members}
|
||||||
era5_member_stats = {m: stats.members[m] for m in era5_members}
|
era5_member_stats = {m: stats.members[m] for m in era5_members}
|
||||||
render_era5_tab(era5_member_dataset, grid_gdf, era5_member_stats)
|
render_era5_tab(era5_member_dataset, grid_gdf, era5_member_stats)
|
||||||
|
tab_index += 1
|
||||||
|
|
||||||
|
with tabs[tab_index]:
|
||||||
|
st.header("🔗 Cross-Dataset Correlation Analysis")
|
||||||
|
# Extract grid area series
|
||||||
|
# grid_area_series = grid_gdf.set_index("cell_id")["area_km2"] if "area_km2" in grid_gdf.columns else None
|
||||||
|
# render_correlations_tab(member_datasets, grid_area_series=grid_area_series)
|
||||||
st.balloons()
|
st.balloons()
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
|
||||||
87
src/entropice/dashboard/views/experiment_analysis_page.py
Normal file
87
src/entropice/dashboard/views/experiment_analysis_page.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Experiment Analysis page: Compare multiple training runs within an experiment."""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.sections.experiment_feature_importance import render_feature_importance_analysis
|
||||||
|
from entropice.dashboard.sections.experiment_grid_analysis import render_grid_level_analysis
|
||||||
|
from entropice.dashboard.sections.experiment_inference_maps import render_inference_maps_section
|
||||||
|
from entropice.dashboard.sections.experiment_model_comparison import render_model_comparison
|
||||||
|
from entropice.dashboard.sections.experiment_overview import (
|
||||||
|
render_experiment_overview,
|
||||||
|
render_experiment_sidebar,
|
||||||
|
)
|
||||||
|
from entropice.dashboard.utils.loaders import (
|
||||||
|
create_experiment_summary_df,
|
||||||
|
load_experiment_autogluon_results,
|
||||||
|
load_experiment_training_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_experiment_analysis_page():
|
||||||
|
"""Render the Experiment Analysis page of the dashboard."""
|
||||||
|
st.title("🔬 Experiment Analysis")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Analyze and compare multiple training runs within an experiment.
|
||||||
|
Select an experiment from the sidebar to explore:
|
||||||
|
- How grid levels affect performance
|
||||||
|
- Best models across different tasks and targets
|
||||||
|
- Feature importance patterns
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Experiment selection
|
||||||
|
selected_experiment = render_experiment_sidebar()
|
||||||
|
|
||||||
|
if selected_experiment is None:
|
||||||
|
st.info("👈 Select an experiment from the sidebar to begin analysis.")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
assert selected_experiment is not None, "Experiment must be selected"
|
||||||
|
|
||||||
|
# Load experiment results
|
||||||
|
with st.spinner(f"Loading results for experiment: {selected_experiment}..."):
|
||||||
|
training_results = load_experiment_training_results(selected_experiment)
|
||||||
|
autogluon_results = load_experiment_autogluon_results(selected_experiment)
|
||||||
|
|
||||||
|
if not training_results and not autogluon_results:
|
||||||
|
st.warning(f"No training results found in experiment: {selected_experiment}")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
# Create summary DataFrame
|
||||||
|
summary_df = create_experiment_summary_df(training_results, autogluon_results)
|
||||||
|
|
||||||
|
# Get available metrics
|
||||||
|
metric_columns = [col for col in summary_df.columns if col.startswith("test_")]
|
||||||
|
available_metrics = [col.replace("test_", "") for col in metric_columns]
|
||||||
|
|
||||||
|
if not available_metrics:
|
||||||
|
st.error("No metrics found in the experiment results.")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
# Render analysis sections
|
||||||
|
render_experiment_overview(selected_experiment, training_results, autogluon_results, summary_df)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
render_grid_level_analysis(summary_df, available_metrics)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
render_model_comparison(summary_df, available_metrics)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
render_feature_importance_analysis(training_results, autogluon_results)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
render_inference_maps_section(selected_experiment, training_results)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Raw data section
|
||||||
|
st.header("📄 Raw Experiment Data")
|
||||||
|
with st.expander("View Complete Summary DataFrame"):
|
||||||
|
st.dataframe(summary_df, width="stretch")
|
||||||
|
|
@ -85,7 +85,7 @@ def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictions_gdf: GeoDataFrame with predictions.
|
predictions_gdf: GeoDataFrame with predictions.
|
||||||
task: Task type ('binary', 'count', 'density').
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.header("📊 Inference Summary")
|
st.header("📊 Inference Summary")
|
||||||
|
|
@ -125,7 +125,7 @@ def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: s
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictions_gdf: GeoDataFrame with predictions.
|
predictions_gdf: GeoDataFrame with predictions.
|
||||||
task: Task type ('binary', 'count', 'density').
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.header("📈 Class Distribution")
|
st.header("📈 Class Distribution")
|
||||||
|
|
@ -138,7 +138,7 @@ def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictions_gdf: GeoDataFrame with predictions.
|
predictions_gdf: GeoDataFrame with predictions.
|
||||||
task: Task type ('binary', 'count', 'density').
|
task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.header("🔍 Class Comparison Analysis")
|
st.header("🔍 Class Comparison Analysis")
|
||||||
|
|
|
||||||
87
src/entropice/experiments/feature_importance.py
Normal file
87
src/entropice/experiments/feature_importance.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.ml.autogluon import RunSettings as AutoGluonRunSettings
|
||||||
|
from entropice.ml.autogluon import train as train_autogluon
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.ml.hpsearchcv import RunSettings as HPOCVRunSettings
|
||||||
|
from entropice.ml.hpsearchcv import hpsearch_cv
|
||||||
|
from entropice.utils.paths import RESULTS_DIR
|
||||||
|
from entropice.utils.types import Grid, Model, TargetDataset, Task
|
||||||
|
|
||||||
|
cli = cyclopts.App("entropice-feature-importance")
|
||||||
|
|
||||||
|
EXPERIMENT_NAME = "feature_importance_era5-shoulder_arcticdem"
|
||||||
|
# EXPERIMENT_NAME = "tobis-final-tests"
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default
|
||||||
|
def main(
|
||||||
|
grid: Grid,
|
||||||
|
target: TargetDataset,
|
||||||
|
):
|
||||||
|
levels = [3, 4, 5, 6] if grid == "hex" else [6, 7, 8, 9, 10]
|
||||||
|
levels = [3, 6] if grid == "hex" else [6, 10]
|
||||||
|
for level in levels:
|
||||||
|
print(f"Running feature importance experiment for {grid} grid at level {level}...")
|
||||||
|
dimension_filters = {"ArcticDEM": {"aggregations": ["median"]}}
|
||||||
|
if (grid == "hex" and level in [3, 4]) or (grid == "healpix" and level in [6, 7]):
|
||||||
|
dimension_filters["ERA5-shoulder"] = {"aggregations": ["median"]}
|
||||||
|
dataset_ensemble = DatasetEnsemble(
|
||||||
|
grid=grid, level=level, members=["ArcticDEM", "ERA5-shoulder"], dimension_filters=dimension_filters
|
||||||
|
)
|
||||||
|
|
||||||
|
for task in cast(list[Task], ["binary", "density"]):
|
||||||
|
print(f"\nRunning for {task}...")
|
||||||
|
|
||||||
|
# AutoGluon
|
||||||
|
time_limit = 30 * 60 # 30 minutes
|
||||||
|
# time_limit = 60
|
||||||
|
presets = "extreme"
|
||||||
|
# presets = "medium"
|
||||||
|
settings = AutoGluonRunSettings(
|
||||||
|
time_limit=time_limit,
|
||||||
|
presets=presets,
|
||||||
|
verbosity=2,
|
||||||
|
task=task,
|
||||||
|
target=target,
|
||||||
|
)
|
||||||
|
train_autogluon(dataset_ensemble, settings, experiment=EXPERIMENT_NAME)
|
||||||
|
|
||||||
|
# HPOCV
|
||||||
|
splitter = "stratified_shuffle" if task == "binary" else "kfold"
|
||||||
|
models: list[Model] = ["xgboost", "rf", "knn"]
|
||||||
|
if task == "binary":
|
||||||
|
models.append("espa")
|
||||||
|
for model in models:
|
||||||
|
print(f"\nRunning HPOCV for model {model}...")
|
||||||
|
n_iter = {
|
||||||
|
"espa": 300,
|
||||||
|
"xgboost": 100,
|
||||||
|
"rf": 40,
|
||||||
|
"knn": 20,
|
||||||
|
}[model]
|
||||||
|
# n_iter = 3
|
||||||
|
scaler = "standard" if model in ["espa", "knn"] else "none"
|
||||||
|
normalize = scaler != "none"
|
||||||
|
settings = HPOCVRunSettings(
|
||||||
|
n_iter=n_iter,
|
||||||
|
task=task,
|
||||||
|
target=target,
|
||||||
|
splitter=splitter,
|
||||||
|
model=model,
|
||||||
|
scaler=scaler,
|
||||||
|
normalize=normalize,
|
||||||
|
)
|
||||||
|
hpsearch_cv(dataset_ensemble, settings, experiment=EXPERIMENT_NAME)
|
||||||
|
|
||||||
|
stopwatch.summary()
|
||||||
|
times = stopwatch.export()
|
||||||
|
times.to_parquet(RESULTS_DIR / EXPERIMENT_NAME / f"training_times_{target}_{grid}.parquet")
|
||||||
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
|
|
@ -27,6 +27,7 @@ pretty.install()
|
||||||
|
|
||||||
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||||
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||||
|
darts_v1_corrections = DARTS_V1_DIR / "negative_correction.geojson"
|
||||||
darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||||
|
|
||||||
cli = cyclopts.App(name="darts-rts")
|
cli = cyclopts.App(name="darts-rts")
|
||||||
|
|
@ -153,6 +154,22 @@ def extract_darts_v1(grid: Grid, level: int):
|
||||||
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
|
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
|
||||||
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
|
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
|
||||||
grid_gdf, cell_areas = _load_grid(grid, level)
|
grid_gdf, cell_areas = _load_grid(grid, level)
|
||||||
|
corrections = gpd.read_file(darts_v1_corrections).to_crs(darts_l2.crs)
|
||||||
|
|
||||||
|
with stopwatch("Apply corrections"):
|
||||||
|
# The correction file is just an area of sure negatives
|
||||||
|
# Thus, we first need to remove all RTS labels that intersect with the correction area,
|
||||||
|
darts_l2 = gpd.overlay(darts_l2, corrections, how="difference")
|
||||||
|
# then we need to add the correction area as coverage to the coverage file per year.
|
||||||
|
darts_cov_l2_cor = []
|
||||||
|
for year in darts_cov_l2["year"].unique():
|
||||||
|
year_cov = darts_cov_l2[darts_cov_l2["year"] == year]
|
||||||
|
year_cov_cor = gpd.overlay(year_cov, corrections, how="union")
|
||||||
|
year_cov_cor["year"] = year
|
||||||
|
darts_cov_l2_cor.append(year_cov_cor)
|
||||||
|
darts_cov_l2_cor = gpd.GeoDataFrame(pd.concat(darts_cov_l2_cor, ignore_index=True))
|
||||||
|
darts_cov_l2_cor["year"] = darts_cov_l2_cor["year"].astype(int)
|
||||||
|
darts_cov_l2 = darts_cov_l2_cor
|
||||||
|
|
||||||
with stopwatch("Assign RTS to grid"):
|
with stopwatch("Assign RTS to grid"):
|
||||||
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,6 @@ This package contains modules for machine learning workflows:
|
||||||
- inference: Batch prediction pipeline for trained classifiers
|
- inference: Batch prediction pipeline for trained classifiers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import dataset, inference, training
|
from . import dataset, inference, randomsearch
|
||||||
|
|
||||||
__all__ = ["dataset", "inference", "training"]
|
__all__ = ["dataset", "inference", "randomsearch"]
|
||||||
|
|
|
||||||
191
src/entropice/ml/autogluon.py
Normal file
191
src/entropice/ml/autogluon.py
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
"""Training of models with AutoGluon."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
|
import pandas as pd
|
||||||
|
import shap.maskers
|
||||||
|
import xarray as xr
|
||||||
|
from autogluon.tabular import TabularDataset, TabularPredictor
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from shap import Explainer, Explanation
|
||||||
|
from sklearn import set_config
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.ml.inference import predict_proba
|
||||||
|
from entropice.utils.paths import get_training_results_dir
|
||||||
|
from entropice.utils.training import AutoML, Training
|
||||||
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
|
||||||
|
cli = cyclopts.App(
|
||||||
|
"entropice-autogluon",
|
||||||
|
config=cyclopts.config.Toml("autogluon-config.toml", root_keys=["tool", "entropice-autogluon"]), # ty:ignore[invalid-argument-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cyclopts.Parameter("*")
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class RunSettings:
|
||||||
|
"""Run settings for training."""
|
||||||
|
|
||||||
|
time_limit: int = 3600 # Time limit in seconds (1 hour default)
|
||||||
|
presets: str = "extreme"
|
||||||
|
task: Task = "binary"
|
||||||
|
target: TargetDataset = "darts_v1"
|
||||||
|
verbosity: int = 2 # Verbosity level (0-4)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_metrics_and_confusion_matrix( # noqa: C901
|
||||||
|
predictor: TabularPredictor, train_data: pd.DataFrame, test_data: pd.DataFrame, complete_data: pd.DataFrame
|
||||||
|
) -> tuple[pd.DataFrame, xr.Dataset | None]:
|
||||||
|
train_scores = predictor.evaluate(train_data, display=True, detailed_report=True)
|
||||||
|
test_scores = predictor.evaluate(test_data, display=True, detailed_report=True)
|
||||||
|
complete_scores = predictor.evaluate(complete_data, display=True, detailed_report=True)
|
||||||
|
m = []
|
||||||
|
cm = {}
|
||||||
|
for dataset, scores in zip(["train", "test", "complete"], [train_scores, test_scores, complete_scores]):
|
||||||
|
for metric, score in scores.items():
|
||||||
|
if metric == "confusion_matrix":
|
||||||
|
score = cast(pd.DataFrame, score)
|
||||||
|
confusion_matrix = xr.DataArray(
|
||||||
|
score.to_numpy(),
|
||||||
|
dims=("y_true", "y_pred"),
|
||||||
|
coords={"y_true": score.index.tolist(), "y_pred": score.columns.tolist()},
|
||||||
|
)
|
||||||
|
cm[dataset] = confusion_matrix
|
||||||
|
elif metric == "classification_report":
|
||||||
|
score = cast(dict[str, dict[str, float]], score)
|
||||||
|
score.pop("accuracy") # Accuracy is already included as a separate metric
|
||||||
|
macro_avg = score.pop("macro avg")
|
||||||
|
for macro_avg_metric, macro_avg_score in macro_avg.items():
|
||||||
|
metric_name = f"macro_avg_{macro_avg_metric}"
|
||||||
|
m.append({"dataset": dataset, "metric": metric_name, "score": macro_avg_score})
|
||||||
|
weighted_avg = score.pop("weighted avg")
|
||||||
|
for weighted_avg_metric, weighted_avg_score in weighted_avg.items():
|
||||||
|
metric_name = f"weighted_avg_{weighted_avg_metric}"
|
||||||
|
m.append({"dataset": dataset, "metric": metric_name, "score": weighted_avg_score})
|
||||||
|
for class_name, class_scores in score.items():
|
||||||
|
class_name = class_name.replace(" ", "-")
|
||||||
|
for class_metric, class_score in class_scores.items():
|
||||||
|
m.append({"dataset": dataset, "metric": f"{class_name}_{class_metric}", "score": class_score})
|
||||||
|
else: # Scalar metric
|
||||||
|
m.append({"dataset": dataset, "metric": metric, "score": score})
|
||||||
|
if len(cm) == 0:
|
||||||
|
return pd.DataFrame(m), None
|
||||||
|
elif len(cm) == 3:
|
||||||
|
return pd.DataFrame(m), xr.Dataset(cm)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Confusion matrix should be computed for all datasets or none.")
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_shap_explanation(
|
||||||
|
predictor: TabularPredictor,
|
||||||
|
train_data: pd.DataFrame,
|
||||||
|
test_data: pd.DataFrame,
|
||||||
|
feature_names: list[str],
|
||||||
|
target_labels: list[str] | None,
|
||||||
|
) -> Explanation:
|
||||||
|
masker = shap.maskers.Independent(data=train_data.drop(columns=["label"]))
|
||||||
|
explainer = Explainer(
|
||||||
|
predictor.predict_proba if predictor.problem_type in ["binary", "multiclass"] else predictor.predict,
|
||||||
|
masker=masker,
|
||||||
|
seed=42,
|
||||||
|
feature_names=feature_names,
|
||||||
|
output_names=target_labels,
|
||||||
|
)
|
||||||
|
samples = test_data.drop(columns=["label"])
|
||||||
|
if len(samples) > 200:
|
||||||
|
samples = samples.sample(n=200, random_state=42)
|
||||||
|
explanation = explainer(samples)
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default
|
||||||
|
def train(
|
||||||
|
dataset_ensemble: DatasetEnsemble,
|
||||||
|
settings: RunSettings = RunSettings(),
|
||||||
|
experiment: str | None = None,
|
||||||
|
) -> Training:
|
||||||
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
|
settings (RunSettings): This runs settings.
|
||||||
|
experiment (str | None): Optional experiment name for results directory.
|
||||||
|
|
||||||
|
"""
|
||||||
|
set_config(array_api_dispatch=False)
|
||||||
|
results_dir = get_training_results_dir(
|
||||||
|
experiment=experiment,
|
||||||
|
name="autogluon",
|
||||||
|
grid=dataset_ensemble.grid,
|
||||||
|
level=dataset_ensemble.level,
|
||||||
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
)
|
||||||
|
print(f"\n💾 Results directory: {results_dir}")
|
||||||
|
|
||||||
|
print("Creating training data...")
|
||||||
|
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)
|
||||||
|
# Convert to AutoGluon TabularDataset
|
||||||
|
train_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("train")) # ty:ignore[invalid-assignment]
|
||||||
|
test_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("test")) # ty:ignore[invalid-assignment]
|
||||||
|
complete_data: pd.DataFrame = TabularDataset(training_data.to_dataframe(None)) # ty:ignore[invalid-assignment]
|
||||||
|
|
||||||
|
# Initialize TabularPredictor
|
||||||
|
print(f"\n🚀 Initializing AutoGluon TabularPredictor (preset='{settings.presets}')...")
|
||||||
|
predictor = TabularPredictor(
|
||||||
|
label="label",
|
||||||
|
path=str(results_dir / "models"),
|
||||||
|
verbosity=settings.verbosity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train models
|
||||||
|
print(f"\n⚡ Training models for {settings.time_limit / 60}min...")
|
||||||
|
with stopwatch("AutoGluon training"):
|
||||||
|
predictor.fit(
|
||||||
|
train_data=train_data,
|
||||||
|
time_limit=settings.time_limit,
|
||||||
|
presets=settings.presets,
|
||||||
|
num_gpus=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Evaluating model performance...")
|
||||||
|
leaderboard = predictor.leaderboard(test_data)
|
||||||
|
feature_importance = predictor.feature_importance(test_data)
|
||||||
|
metrics, confusion_matrix = _compute_metrics_and_confusion_matrix(predictor, train_data, test_data, complete_data)
|
||||||
|
|
||||||
|
with stopwatch("Explaining model predictions with SHAP..."):
|
||||||
|
explanation = _compute_shap_explanation(
|
||||||
|
predictor, train_data, test_data, training_data.feature_names, training_data.target_labels
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Predicting probabilities for all cells...")
|
||||||
|
preds = predict_proba(dataset_ensemble, model=predictor, task=settings.task)
|
||||||
|
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
|
||||||
|
|
||||||
|
summary = Training(
|
||||||
|
path=results_dir,
|
||||||
|
dataset=dataset_ensemble,
|
||||||
|
method=AutoML(time_budget=settings.time_limit, preset=settings.presets, hpo=False),
|
||||||
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
training_set=training_data,
|
||||||
|
model=predictor,
|
||||||
|
model_type="autogluon",
|
||||||
|
metrics=metrics,
|
||||||
|
feature_importance=feature_importance,
|
||||||
|
shap_explanation=explanation,
|
||||||
|
predictions=preds,
|
||||||
|
confusion_matrix=confusion_matrix,
|
||||||
|
cv_results=None,
|
||||||
|
leaderboard=leaderboard,
|
||||||
|
)
|
||||||
|
summary.save()
|
||||||
|
return summary
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Training with AutoGluon TabularPredictor for automated ML."""
|
"""DePRECATED!!! Training with AutoGluon TabularPredictor for automated ML."""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
|
@ -12,7 +12,7 @@ from sklearn import set_config
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.utils.paths import get_autogluon_results_dir
|
from entropice.utils.paths import get_training_results_dir
|
||||||
from entropice.utils.types import TargetDataset, Task
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
|
|
@ -101,11 +101,13 @@ def autogluon_train(
|
||||||
print(f"📈 Evaluation metric: {eval_metric}")
|
print(f"📈 Evaluation metric: {eval_metric}")
|
||||||
|
|
||||||
# Create results directory
|
# Create results directory
|
||||||
results_dir = get_autogluon_results_dir(
|
results_dir = get_training_results_dir(
|
||||||
experiment=experiment,
|
experiment=experiment,
|
||||||
grid=dataset_ensemble.grid,
|
grid=dataset_ensemble.grid,
|
||||||
level=dataset_ensemble.level,
|
level=dataset_ensemble.level,
|
||||||
task=settings.task,
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
name="autogluon",
|
||||||
)
|
)
|
||||||
print(f"\n💾 Results directory: {results_dir}")
|
print(f"\n💾 Results directory: {results_dir}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from collections.abc import Generator
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from functools import cache, cached_property
|
from functools import cache, cached_property
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Literal, cast
|
from typing import Literal, TypeVar, cast
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -28,6 +28,7 @@ import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from cuml import KMeans
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
@ -118,26 +119,29 @@ def bin_values(
|
||||||
return binned
|
return binned
|
||||||
|
|
||||||
|
|
||||||
|
ArrayType = TypeVar("ArrayType", torch.Tensor, np.ndarray, cp.ndarray)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, eq=False)
|
@dataclass(frozen=True, eq=False)
|
||||||
class SplittedArrays:
|
class SplittedArrays[ArrayType: (torch.Tensor, np.ndarray, cp.ndarray)]:
|
||||||
"""Small wrapper for train and test arrays."""
|
"""Small wrapper for train and test arrays."""
|
||||||
|
|
||||||
train: torch.Tensor | np.ndarray | cp.ndarray
|
train: ArrayType
|
||||||
test: torch.Tensor | np.ndarray | cp.ndarray
|
test: ArrayType
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def combined(self) -> torch.Tensor | np.ndarray | cp.ndarray:
|
def combined(self) -> ArrayType:
|
||||||
"""Combined train and test arrays."""
|
"""Combined train and test arrays."""
|
||||||
if isinstance(self.train, torch.Tensor) and isinstance(self.test, torch.Tensor):
|
if isinstance(self.train, torch.Tensor) and isinstance(self.test, torch.Tensor):
|
||||||
return torch.cat([self.train, self.test], dim=0)
|
return torch.cat([self.train, self.test], dim=0) # ty:ignore[invalid-return-type]
|
||||||
elif isinstance(self.train, cp.ndarray) and isinstance(self.test, cp.ndarray):
|
elif isinstance(self.train, cp.ndarray) and isinstance(self.test, cp.ndarray):
|
||||||
return cp.concatenate([self.train, self.test], axis=0)
|
return cp.concatenate([self.train, self.test], axis=0)
|
||||||
elif isinstance(self.train, np.ndarray) and isinstance(self.test, np.ndarray):
|
elif isinstance(self.train, np.ndarray) and isinstance(self.test, np.ndarray):
|
||||||
return np.concatenate([self.train, self.test], axis=0)
|
return np.concatenate([self.train, self.test], axis=0) # ty:ignore[invalid-return-type]
|
||||||
else:
|
else:
|
||||||
raise TypeError("Incompatible types for train and test arrays.")
|
raise TypeError("Incompatible types for train and test arrays.")
|
||||||
|
|
||||||
def as_numpy(self) -> "SplittedArrays":
|
def as_numpy(self) -> "SplittedArrays[np.ndarray]":
|
||||||
"""Return the arrays as numpy arrays."""
|
"""Return the arrays as numpy arrays."""
|
||||||
train_np = (
|
train_np = (
|
||||||
self.train.cpu().numpy()
|
self.train.cpu().numpy()
|
||||||
|
|
@ -157,13 +161,13 @@ class SplittedArrays:
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, eq=False)
|
@dataclass(frozen=True, eq=False)
|
||||||
class TrainingSet:
|
class TrainingSet[ArrayType: (torch.Tensor, np.ndarray, cp.ndarray)]:
|
||||||
"""Container for the training dataset."""
|
"""Container for the training dataset."""
|
||||||
|
|
||||||
targets: gpd.GeoDataFrame
|
targets: gpd.GeoDataFrame
|
||||||
features: pd.DataFrame
|
features: pd.DataFrame
|
||||||
X: SplittedArrays
|
X: SplittedArrays[ArrayType]
|
||||||
y: SplittedArrays
|
y: SplittedArrays[ArrayType]
|
||||||
z: pd.Series
|
z: pd.Series
|
||||||
split: pd.Series
|
split: pd.Series
|
||||||
|
|
||||||
|
|
@ -214,6 +218,16 @@ class TrainingSet:
|
||||||
"""Names of the features."""
|
"""Names of the features."""
|
||||||
return self.features.columns.tolist()
|
return self.features.columns.tolist()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def clusters(self) -> pd.Series:
|
||||||
|
"""Geo-Cluster assignments for each sample, based on the geometries of the target GeoDataFrame."""
|
||||||
|
centroids = self.targets.to_crs("EPSG:3413")["geometry"].centroid
|
||||||
|
centroids = cp.array([centroids.x.to_numpy(), centroids.y.to_numpy()]).T
|
||||||
|
# Use kMeans to cluster the centroids into 10 clusters
|
||||||
|
kmeans = KMeans(n_clusters=10, random_state=42)
|
||||||
|
clusters = kmeans.fit_predict(centroids).get()
|
||||||
|
return pd.Series(clusters, index=self.targets.index)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.z)
|
return len(self.z)
|
||||||
|
|
||||||
|
|
|
||||||
0
src/entropice/ml/explainations.py
Normal file
0
src/entropice/ml/explainations.py
Normal file
395
src/entropice/ml/hpsearchcv.py
Normal file
395
src/entropice/ml/hpsearchcv.py
Normal file
|
|
@ -0,0 +1,395 @@
|
||||||
|
"""Training of models with hyperparameter search with cross-validation."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
|
import cupy as cp
|
||||||
|
import cyclopts
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import shap.maskers
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from shap import Explainer, Explanation, TreeExplainer
|
||||||
|
from sklearn import set_config
|
||||||
|
from sklearn.inspection import permutation_importance
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
from sklearn.model_selection import RandomizedSearchCV
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.preprocessing import (
|
||||||
|
FunctionTransformer,
|
||||||
|
MaxAbsScaler,
|
||||||
|
MinMaxScaler,
|
||||||
|
Normalizer,
|
||||||
|
QuantileTransformer,
|
||||||
|
StandardScaler,
|
||||||
|
)
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, SplittedArrays, TrainingSet
|
||||||
|
from entropice.ml.inference import predict_proba
|
||||||
|
from entropice.ml.models import (
|
||||||
|
ModelHPOConfig,
|
||||||
|
extract_espa_feature_importance,
|
||||||
|
extract_rf_feature_importance,
|
||||||
|
extract_xgboost_feature_importance,
|
||||||
|
get_model_hpo_config,
|
||||||
|
get_splitter,
|
||||||
|
)
|
||||||
|
from entropice.utils.metrics import get_metrics, metric_functions
|
||||||
|
from entropice.utils.paths import get_training_results_dir
|
||||||
|
from entropice.utils.training import HPOCV, Training, move_data_to_device
|
||||||
|
from entropice.utils.types import HPSearch, Model, Scaler, Splitter, TargetDataset, Task
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
|
||||||
|
cli = cyclopts.App(
|
||||||
|
"entropice-hpsearchcv",
|
||||||
|
config=cyclopts.config.Toml("hpsearchcv-config.toml", root_keys=["tool", "entropice-hpsearchcv"]), # ty:ignore[invalid-argument-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cyclopts.Parameter("*")
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class RunSettings:
|
||||||
|
"""Run settings for training."""
|
||||||
|
|
||||||
|
n_iter: int = 2000
|
||||||
|
task: Task = "binary"
|
||||||
|
target: TargetDataset = "darts_v1"
|
||||||
|
search: HPSearch = "random" # TODO: Implement grid and bayesian search
|
||||||
|
splitter: Splitter = "stratified_shuffle"
|
||||||
|
model: Model = "espa"
|
||||||
|
scaler: Scaler = "standard"
|
||||||
|
normalize: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Literal["torch", "cuda", "cpu"]:
|
||||||
|
"""Get the device to use for model training.
|
||||||
|
|
||||||
|
Note: Device Management currently is super nasty:
|
||||||
|
- CuML (RF & kNN) expects cupy ("cuda") arrays
|
||||||
|
- XGBoost always returns CPU arrays
|
||||||
|
- eSPA for some reason is super slow on cupy but very fast on torch tensors
|
||||||
|
- The sklearn permutation stuff does not work with torch tensors, only cupy and numpy arrays
|
||||||
|
- SHAP in general only works with CPU
|
||||||
|
"""
|
||||||
|
return "torch" if self.model == "espa" else "cuda"
|
||||||
|
|
||||||
|
def build_pipeline(self, model_hpo_config: ModelHPOConfig) -> Pipeline: # noqa: C901
|
||||||
|
"""Build a scikit-learn Pipeline based on the settings."""
|
||||||
|
# Add a feature scaler / normalization step if specified, but assert that it's only used for non-Tree models
|
||||||
|
if self.model in ["rf", "xgboost"]:
|
||||||
|
assert self.scaler == "none", f"Scaler {self.scaler} is not viable with model {self.model}"
|
||||||
|
elif self.scaler == "none":
|
||||||
|
assert self.scaler != "none", f"No scaler specified for model {self.model}, which is not viable."
|
||||||
|
|
||||||
|
match self.scaler:
|
||||||
|
case "standard":
|
||||||
|
scaler = StandardScaler()
|
||||||
|
case "minmax":
|
||||||
|
scaler = MinMaxScaler()
|
||||||
|
case "maxabs":
|
||||||
|
scaler = MaxAbsScaler()
|
||||||
|
case "quantile":
|
||||||
|
scaler = QuantileTransformer(output_distribution="normal")
|
||||||
|
case "none":
|
||||||
|
scaler = None
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown scaler: {self.scaler}")
|
||||||
|
|
||||||
|
pipeline_steps = []
|
||||||
|
if scaler is not None:
|
||||||
|
print(f"Using {scaler.__class__.__name__} for feature scaling.")
|
||||||
|
pipeline_steps.append(("scaler", scaler))
|
||||||
|
if self.normalize:
|
||||||
|
print("Using Normalizer for feature normalization.")
|
||||||
|
pipeline_steps.append(("normalizer", Normalizer()))
|
||||||
|
# Necessary, because scaler and normalizer are only GPU ready in 1.8.0 but autogluon requries <1.8.0
|
||||||
|
if len(pipeline_steps) > 0 and self.device != "cpu":
|
||||||
|
print(f"Adding steps to move data to CPU for preprocessing and back to {self.device} for model fitting.")
|
||||||
|
to_cpu = FunctionTransformer(move_data_to_device, kw_args={"device": "cpu"})
|
||||||
|
pipeline_steps.insert(0, ("to_cpu", to_cpu))
|
||||||
|
to_gpu = FunctionTransformer(move_data_to_device, kw_args={"device": self.device})
|
||||||
|
pipeline_steps.append(("to_device", to_gpu))
|
||||||
|
pipeline_steps.append(("model", model_hpo_config.model))
|
||||||
|
return Pipeline(pipeline_steps)
|
||||||
|
|
||||||
|
def build_search(
|
||||||
|
self, pipeline: Pipeline, model_hpo_config: ModelHPOConfig, metrics: list[str], refit: str
|
||||||
|
) -> RandomizedSearchCV:
|
||||||
|
"""Build a scikit-learn RandomizedSearchCV based on the settings."""
|
||||||
|
if self.task in ["density", "count"]:
|
||||||
|
assert self.splitter not in ["stratified_shuffle", "stratified_kfold"], (
|
||||||
|
f"Splitter {self.splitter} is not viable for regression tasks"
|
||||||
|
)
|
||||||
|
|
||||||
|
cv = get_splitter(self.splitter, n_splits=5)
|
||||||
|
print(f"Using {cv.__class__.__name__} for cross-validation splitting.")
|
||||||
|
|
||||||
|
search_space = {f"model__{k}": v for k, v in model_hpo_config.search_space.items()}
|
||||||
|
if self.search == "random":
|
||||||
|
hp_search = RandomizedSearchCV(
|
||||||
|
estimator=pipeline,
|
||||||
|
param_distributions=search_space,
|
||||||
|
n_iter=self.n_iter,
|
||||||
|
cv=cv,
|
||||||
|
scoring=metrics,
|
||||||
|
refit=refit,
|
||||||
|
random_state=42,
|
||||||
|
verbose=2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Search method {self.search} not implemented yet.")
|
||||||
|
return hp_search
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cv_results(cv_results) -> pd.DataFrame:
|
||||||
|
cv_results = pd.DataFrame(cv_results)
|
||||||
|
# Parse the params into individual columns
|
||||||
|
params = pd.json_normalize(cv_results["params"]) # ty:ignore[invalid-argument-type]
|
||||||
|
cv_results = pd.concat([cv_results.drop(columns=["params"]), params], axis=1)
|
||||||
|
return cv_results
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_metrics(y: SplittedArrays, y_pred: SplittedArrays, metrics: list[str]) -> pd.DataFrame:
|
||||||
|
m = []
|
||||||
|
for metric in metrics:
|
||||||
|
metric_fn = metric_functions[metric]
|
||||||
|
for split in ["train", "test", "combined"]:
|
||||||
|
value = metric_fn(getattr(y, split), getattr(y_pred, split))
|
||||||
|
m.append({"metric": metric, "split": split, "value": value})
|
||||||
|
return pd.DataFrame(m)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_confusion_matrices(
|
||||||
|
y: SplittedArrays, y_pred: SplittedArrays, codes: np.ndarray, labels: list[str] | None
|
||||||
|
) -> xr.Dataset:
|
||||||
|
if labels is None:
|
||||||
|
labels = [str(code) for code in codes]
|
||||||
|
cm = xr.Dataset(
|
||||||
|
{
|
||||||
|
"test": (("true_label", "predicted_label"), confusion_matrix(y.test, y_pred.test, labels=codes)),
|
||||||
|
"train": (("true_label", "predicted_label"), confusion_matrix(y.train, y_pred.train, labels=codes)),
|
||||||
|
"combined": (
|
||||||
|
("true_label", "predicted_label"),
|
||||||
|
confusion_matrix(y.combined, y_pred.combined, labels=codes),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
coords={"true_label": labels, "predicted_label": labels},
|
||||||
|
)
|
||||||
|
return cm
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_feature_importance(model: Model, best_estimator: Pipeline, training_data: TrainingSet) -> pd.DataFrame:
|
||||||
|
X_test = training_data.X.as_numpy().test if model in ["xgboost", "espa"] else training_data.X.test # noqa: N806
|
||||||
|
y_test = training_data.y.as_numpy().test if model in ["xgboost", "espa"] else training_data.y.test
|
||||||
|
if model == "espa":
|
||||||
|
# eSPA is super slow with cupy arrays, so we need to move the data to CPU for permutation importance
|
||||||
|
# To to this, we manually recreate the pipeline but without the device steps
|
||||||
|
if "scaler" in best_estimator.named_steps:
|
||||||
|
X_test = best_estimator.named_steps["scaler"].transform(X_test) # noqa: N806
|
||||||
|
if "normalizer" in best_estimator.named_steps:
|
||||||
|
X_test = best_estimator.named_steps["normalizer"].transform(X_test) # noqa: N806
|
||||||
|
best_estimator.named_steps["model"].to_numpy() # inplace
|
||||||
|
|
||||||
|
r = permutation_importance(
|
||||||
|
best_estimator if model != "espa" else best_estimator.named_steps["model"],
|
||||||
|
X_test,
|
||||||
|
y_test,
|
||||||
|
n_repeats=5,
|
||||||
|
random_state=0,
|
||||||
|
max_samples=min(5000, training_data.X.test.shape[0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if model == "espa":
|
||||||
|
best_estimator.named_steps["model"].to_torch() # inplace
|
||||||
|
|
||||||
|
feature_importances = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"importance": r["importances_mean"],
|
||||||
|
"stddev": r["importances_std"],
|
||||||
|
},
|
||||||
|
index=training_data.feature_names,
|
||||||
|
)
|
||||||
|
match model:
|
||||||
|
case "espa":
|
||||||
|
model_feature_importances = extract_espa_feature_importance(
|
||||||
|
best_estimator.named_steps["model"], training_data
|
||||||
|
).rename(columns={"importance": "model_feature_weights"})
|
||||||
|
case "xgboost":
|
||||||
|
model_feature_importances = extract_xgboost_feature_importance(
|
||||||
|
best_estimator.named_steps["model"], training_data
|
||||||
|
)
|
||||||
|
case "rf":
|
||||||
|
model_feature_importances = extract_rf_feature_importance(
|
||||||
|
best_estimator.named_steps["model"], training_data
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
model_feature_importances = None
|
||||||
|
if model_feature_importances is not None:
|
||||||
|
feature_importances = feature_importances.join(model_feature_importances)
|
||||||
|
return feature_importances
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_shap_explanation(model: Model, best_estimator: Pipeline, training_data: TrainingSet) -> Explanation:
|
||||||
|
match model:
|
||||||
|
case "espa" | "knn" | "rf": # CUML models do not yet work with TreeExplainer...
|
||||||
|
train_transformed = training_data.X.as_numpy().train
|
||||||
|
if "scaler" in best_estimator.named_steps:
|
||||||
|
train_transformed = best_estimator.named_steps["scaler"].transform(train_transformed)
|
||||||
|
if "normalizer" in best_estimator.named_steps:
|
||||||
|
train_transformed = best_estimator.named_steps["normalizer"].transform(train_transformed)
|
||||||
|
masker = shap.maskers.Independent(data=train_transformed)
|
||||||
|
|
||||||
|
def _model_predict(data):
|
||||||
|
mdl = best_estimator.named_steps["model"]
|
||||||
|
f = mdl.predict_proba if hasattr(mdl, "predict_proba") else mdl.predict
|
||||||
|
|
||||||
|
out = f(data)
|
||||||
|
if isinstance(out, torch.Tensor):
|
||||||
|
out = out.cpu().numpy()
|
||||||
|
elif isinstance(out, cp.ndarray):
|
||||||
|
out = out.get()
|
||||||
|
return out
|
||||||
|
|
||||||
|
explainer = Explainer(
|
||||||
|
_model_predict,
|
||||||
|
masker=masker,
|
||||||
|
seed=42,
|
||||||
|
feature_names=training_data.feature_names,
|
||||||
|
output_names=training_data.target_labels,
|
||||||
|
)
|
||||||
|
case "xgboost":
|
||||||
|
explainer = TreeExplainer(best_estimator.named_steps["model"], feature_names=training_data.feature_names)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown model: {model}")
|
||||||
|
|
||||||
|
samples = training_data.X.as_numpy().test
|
||||||
|
if len(samples) > 200:
|
||||||
|
rng = np.random.default_rng(seed=42)
|
||||||
|
sample_indices = rng.choice(len(samples), size=200, replace=False)
|
||||||
|
samples = samples[sample_indices]
|
||||||
|
if "scaler" in best_estimator.named_steps:
|
||||||
|
samples = best_estimator.named_steps["scaler"].transform(samples)
|
||||||
|
if "normalizer" in best_estimator.named_steps:
|
||||||
|
samples = best_estimator.named_steps["normalizer"].transform(samples)
|
||||||
|
explanation = explainer(samples)
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default
|
||||||
|
def hpsearch_cv(
|
||||||
|
dataset_ensemble: DatasetEnsemble,
|
||||||
|
settings: RunSettings = RunSettings(),
|
||||||
|
experiment: str | None = None,
|
||||||
|
) -> Training:
|
||||||
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
|
settings (RunSettings): This runs settings.
|
||||||
|
experiment (str | None): Optional experiment name for results directory.
|
||||||
|
|
||||||
|
"""
|
||||||
|
results_dir = get_training_results_dir(
|
||||||
|
experiment=experiment,
|
||||||
|
name="random_search",
|
||||||
|
grid=dataset_ensemble.grid,
|
||||||
|
level=dataset_ensemble.level,
|
||||||
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
model_type=settings.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
|
||||||
|
use_array_api = settings.model != "xgboost"
|
||||||
|
set_config(array_api_dispatch=use_array_api)
|
||||||
|
|
||||||
|
print("Creating training data...")
|
||||||
|
training_data = dataset_ensemble.create_training_set(
|
||||||
|
task=settings.task, target=settings.target, device=settings.device
|
||||||
|
)
|
||||||
|
|
||||||
|
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
||||||
|
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
||||||
|
|
||||||
|
metrics, refit = get_metrics(settings.task)
|
||||||
|
print(f"Using {len(metrics)} metrics as scoring and {refit} for refitting.")
|
||||||
|
|
||||||
|
pipeline = settings.build_pipeline(model_hpo_config)
|
||||||
|
print(f"Pipeline steps: {pipeline.named_steps}")
|
||||||
|
|
||||||
|
hp_search = settings.build_search(pipeline, model_hpo_config, metrics, refit)
|
||||||
|
print(f"Starting hyperparameter search with {settings.n_iter} iterations...")
|
||||||
|
with stopwatch(f"RandomizedSearchCV fitting for {settings.n_iter} candidates"):
|
||||||
|
fit_params = {f"model__{k}": v for k, v in model_hpo_config.fit_params.items()}
|
||||||
|
hp_search.fit(
|
||||||
|
training_data.X.train,
|
||||||
|
# XGBoost returns it's labels as numpy arrays instead of cupy arrays
|
||||||
|
# Thus, for the scoring to work, we need to convert them back to numpy
|
||||||
|
training_data.y.as_numpy().train if settings.model == "xgboost" else training_data.y.train,
|
||||||
|
**fit_params,
|
||||||
|
)
|
||||||
|
print("Best parameters combination found:")
|
||||||
|
best_estimator = cast(Pipeline, hp_search.best_estimator_)
|
||||||
|
best_parameters = best_estimator.get_params()
|
||||||
|
for param_name in sorted(model_hpo_config.hp_config.keys()):
|
||||||
|
search_param_name = f"model__{param_name}"
|
||||||
|
print(f"{param_name}: {best_parameters[search_param_name]}")
|
||||||
|
|
||||||
|
# Compute predictions on the all sets and move them to numpy for metric computations
|
||||||
|
y_pred = SplittedArrays(
|
||||||
|
train=best_estimator.predict(training_data.X.train),
|
||||||
|
test=best_estimator.predict(training_data.X.test),
|
||||||
|
).as_numpy()
|
||||||
|
y = training_data.y.as_numpy()
|
||||||
|
metrics = _compute_metrics(y, y_pred, metrics)
|
||||||
|
if settings.task in ["binary", "count_regimes", "density_regimes"]:
|
||||||
|
confusion_matrix = _compute_confusion_matrices(
|
||||||
|
y, y_pred, codes=np.array(training_data.target_codes), labels=training_data.target_labels
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
confusion_matrix = None
|
||||||
|
print("Metrics computed!")
|
||||||
|
|
||||||
|
with stopwatch("Computing feature importance with permutation importance..."):
|
||||||
|
feature_importance = _compute_feature_importance(settings.model, best_estimator, training_data)
|
||||||
|
|
||||||
|
with stopwatch("Explaining model predictions with SHAP..."):
|
||||||
|
explanation = _compute_shap_explanation(settings.model, best_estimator, training_data)
|
||||||
|
|
||||||
|
print("Predicting probabilities for all cells...")
|
||||||
|
preds = predict_proba(dataset_ensemble, model=best_estimator, task=settings.task, device=settings.device)
|
||||||
|
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
|
||||||
|
|
||||||
|
summary = Training(
|
||||||
|
path=results_dir,
|
||||||
|
dataset=dataset_ensemble,
|
||||||
|
method=HPOCV(
|
||||||
|
method=settings.search,
|
||||||
|
splitter=settings.splitter,
|
||||||
|
scaler=settings.scaler,
|
||||||
|
normalize=settings.normalize,
|
||||||
|
n_iter=settings.n_iter,
|
||||||
|
hpconfig=model_hpo_config.hp_config,
|
||||||
|
),
|
||||||
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
training_set=training_data,
|
||||||
|
model=best_estimator,
|
||||||
|
model_type=settings.model,
|
||||||
|
metrics=metrics,
|
||||||
|
feature_importance=feature_importance,
|
||||||
|
shap_explanation=explanation,
|
||||||
|
predictions=preds,
|
||||||
|
confusion_matrix=confusion_matrix,
|
||||||
|
cv_results=_extract_cv_results(hp_search.cv_results_),
|
||||||
|
leaderboard=None,
|
||||||
|
)
|
||||||
|
summary.save()
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
@ -5,13 +5,17 @@ from typing import Literal
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
from autogluon.tabular import TabularPredictor
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.ml.models import SupportedModel
|
from entropice.ml.models import SupportedModel, is_classifier, is_regressor
|
||||||
|
from entropice.utils.types import Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -19,16 +23,53 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_right_model_for_task(model: Pipeline | SupportedModel | TabularPredictor, task: Task) -> None:
|
||||||
|
assert hasattr(model, "predict") or hasattr(model, "predict_proba"), (
|
||||||
|
"Model must have a predict or predict_proba method"
|
||||||
|
)
|
||||||
|
is_classification = task in ["binary", "count_regimes", "density_regimes"]
|
||||||
|
is_regression = task in ["count", "density"]
|
||||||
|
assert is_classification != is_regression, "Task must be either classification or regression"
|
||||||
|
assert is_classifier(model) == is_classification, f"Model type does not match task type for task {task}"
|
||||||
|
assert is_regressor(model) == is_regression, f"Model type does not match task type for task {task}"
|
||||||
|
|
||||||
|
|
||||||
|
def _categorize_predictions(
|
||||||
|
preds: pd.Series | np.ndarray,
|
||||||
|
task: Task,
|
||||||
|
) -> pd.Series | pd.Categorical:
|
||||||
|
"""Convert the raw model predictions into a category type series."""
|
||||||
|
if isinstance(preds, np.ndarray):
|
||||||
|
preds = pd.Series(preds)
|
||||||
|
match task:
|
||||||
|
case "binary" | "count_regimes" | "density_regimes":
|
||||||
|
labels_dict = {
|
||||||
|
"binary": ["No RTS", "RTS"],
|
||||||
|
"count_regimes": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
|
||||||
|
"density_regimes": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
|
||||||
|
}
|
||||||
|
categories = pd.CategoricalDtype(categories=labels_dict[task], ordered=task != "binary")
|
||||||
|
# Check if preds are codes or labels
|
||||||
|
if preds.dtype == "object" or isinstance(preds.iloc[0], str):
|
||||||
|
return pd.Categorical(preds, dtype=categories)
|
||||||
|
else:
|
||||||
|
return pd.Categorical.from_codes(preds.astype(int).to_list(), dtype=categories)
|
||||||
|
case _:
|
||||||
|
return preds
|
||||||
|
|
||||||
|
|
||||||
def predict_proba(
|
def predict_proba(
|
||||||
e: DatasetEnsemble,
|
e: DatasetEnsemble,
|
||||||
model: SupportedModel,
|
model: Pipeline | SupportedModel | TabularPredictor,
|
||||||
|
task: Task,
|
||||||
device: Literal["cpu", "cuda", "torch"] = "cuda",
|
device: Literal["cpu", "cuda", "torch"] = "cuda",
|
||||||
) -> gpd.GeoDataFrame:
|
) -> gpd.GeoDataFrame:
|
||||||
"""Get predicted probabilities for each cell.
|
"""Get predicted probabilities for each cell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
e (DatasetEnsemble): The dataset ensemble configuration.
|
e (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
model: SupportedModel: The trained model to use for predictions.
|
model: SupportedModel | TabularPredictor: The trained model to use for predictions.
|
||||||
|
task (Task): The task.
|
||||||
device (Literal["cpu", "cuda", "torch"]): The device to use for predictions.
|
device (Literal["cpu", "cuda", "torch"]): The device to use for predictions.
|
||||||
This must match with the state of the model!
|
This must match with the state of the model!
|
||||||
|
|
||||||
|
|
@ -36,6 +77,7 @@ def predict_proba(
|
||||||
gpd.GeoDataFrame: A GeoDataFrame with cell_id, predicted probability, and geometry.
|
gpd.GeoDataFrame: A GeoDataFrame with cell_id, predicted probability, and geometry.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
_assert_right_model_for_task(model, task)
|
||||||
# Predict in batches to avoid memory issues
|
# Predict in batches to avoid memory issues
|
||||||
batch_size = 50000
|
batch_size = 50000
|
||||||
preds = []
|
preds = []
|
||||||
|
|
@ -52,16 +94,27 @@ def predict_proba(
|
||||||
cell_ids = batch.index.to_numpy()
|
cell_ids = batch.index.to_numpy()
|
||||||
cell_geoms = grid_gdf.loc[batch.index, "geometry"].to_numpy()
|
cell_geoms = grid_gdf.loc[batch.index, "geometry"].to_numpy()
|
||||||
|
|
||||||
X_batch = batch.to_numpy(dtype="float64")
|
if isinstance(model, TabularPredictor):
|
||||||
if device == "torch":
|
print(f"Predicting batch of size {len(batch)} ({type(batch)}) with AutoGluon TabularPredictor...")
|
||||||
X_batch = torch.from_numpy(X_batch).to("cuda")
|
batch_preds = model.predict(batch)
|
||||||
elif device == "cuda":
|
print(f"Batch predictions type: {type(batch_preds)}, shape: {batch_preds.shape}")
|
||||||
X_batch = cp.asarray(X_batch)
|
|
||||||
batch_preds = model.predict(X_batch)
|
assert isinstance(batch_preds, pd.DataFrame | pd.Series), (
|
||||||
if isinstance(batch_preds, cp.ndarray):
|
"AutoGluon predict should return a DataFrame or Series"
|
||||||
batch_preds = batch_preds.get()
|
)
|
||||||
elif torch.is_tensor(batch_preds):
|
batch_preds = batch_preds.to_numpy()
|
||||||
batch_preds = batch_preds.cpu().numpy()
|
else:
|
||||||
|
X_batch = batch.to_numpy(dtype="float64")
|
||||||
|
if device == "torch":
|
||||||
|
X_batch = torch.from_numpy(X_batch).to("cuda")
|
||||||
|
elif device == "cuda":
|
||||||
|
X_batch = cp.asarray(X_batch)
|
||||||
|
batch_preds = model.predict(X_batch)
|
||||||
|
if isinstance(batch_preds, cp.ndarray):
|
||||||
|
batch_preds = batch_preds.get()
|
||||||
|
elif torch.is_tensor(batch_preds):
|
||||||
|
batch_preds = batch_preds.cpu().numpy()
|
||||||
|
batch_preds = _categorize_predictions(batch_preds, task=task)
|
||||||
batch_preds = gpd.GeoDataFrame(
|
batch_preds = gpd.GeoDataFrame(
|
||||||
{
|
{
|
||||||
"cell_id": cell_ids,
|
"cell_id": cell_ids,
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,30 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
|
import cupy as cp
|
||||||
|
import pandas as pd
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from autogluon.tabular import TabularPredictor
|
||||||
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
|
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
|
||||||
from cuml.neighbors import KNeighborsClassifier, KNeighborsRegressor
|
from cuml.neighbors import KNeighborsClassifier, KNeighborsRegressor
|
||||||
from entropy import ESPAClassifier
|
from entropy import ESPAClassifier
|
||||||
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
||||||
|
from sklearn.model_selection import (
|
||||||
|
GroupKFold,
|
||||||
|
GroupShuffleSplit,
|
||||||
|
KFold,
|
||||||
|
ShuffleSplit,
|
||||||
|
StratifiedGroupKFold,
|
||||||
|
StratifiedKFold,
|
||||||
|
StratifiedShuffleSplit,
|
||||||
|
)
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
from xgboost.sklearn import XGBClassifier, XGBRegressor
|
from xgboost.sklearn import XGBClassifier, XGBRegressor
|
||||||
|
|
||||||
from entropice.ml.dataset import TrainingSet
|
from entropice.ml.dataset import TrainingSet
|
||||||
from entropice.utils.types import Task
|
from entropice.utils.types import Splitter, Task
|
||||||
|
|
||||||
|
|
||||||
class Distribution(TypedDict):
|
class Distribution(TypedDict):
|
||||||
|
|
@ -35,6 +49,47 @@ type SupportedModel = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_search_space(hp_config: HPConfig) -> dict[str, list | rv_continuous_frozen | rv_discrete_frozen]:
|
||||||
|
"""Convert the HPConfig into a search space dictionary usable by sklearn's RandomizedSearchCV."""
|
||||||
|
search_space = {}
|
||||||
|
for key, dist in hp_config.items():
|
||||||
|
if isinstance(dist, list):
|
||||||
|
search_space[key] = dist
|
||||||
|
continue
|
||||||
|
assert hasattr(scipy.stats, dist["distribution"]), (
|
||||||
|
f"Unknown distribution type for {key}: {dist['distribution']}"
|
||||||
|
)
|
||||||
|
distfn = getattr(scipy.stats, dist["distribution"])
|
||||||
|
search_space[key] = distfn(dist["low"], dist["high"])
|
||||||
|
return search_space
|
||||||
|
|
||||||
|
|
||||||
|
def is_classifier(model: Pipeline | SupportedModel | TabularPredictor) -> bool:
|
||||||
|
"""Check if the model is a classifier."""
|
||||||
|
if isinstance(model, TabularPredictor):
|
||||||
|
return model.problem_type in ["binary", "multiclass"]
|
||||||
|
if isinstance(model, Pipeline):
|
||||||
|
# Check the last step of the pipeline for the model type
|
||||||
|
return is_classifier(model.steps[-1][1])
|
||||||
|
return isinstance(
|
||||||
|
model,
|
||||||
|
(ESPAClassifier, XGBClassifier, RandomForestClassifier, KNeighborsClassifier),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_regressor(model: Pipeline | SupportedModel | TabularPredictor) -> bool:
|
||||||
|
"""Check if the model is a regressor."""
|
||||||
|
if isinstance(model, TabularPredictor):
|
||||||
|
return model.problem_type in ["regression"]
|
||||||
|
if isinstance(model, Pipeline):
|
||||||
|
# Check the last step of the pipeline for the model type
|
||||||
|
return is_regressor(model.steps[-1][1])
|
||||||
|
return isinstance(
|
||||||
|
model,
|
||||||
|
(XGBRegressor, RandomForestRegressor, KNeighborsRegressor),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ModelHPOConfig:
|
class ModelHPOConfig:
|
||||||
"""Model - Hyperparameter Optimization Configuration."""
|
"""Model - Hyperparameter Optimization Configuration."""
|
||||||
|
|
@ -46,32 +101,43 @@ class ModelHPOConfig:
|
||||||
@property
|
@property
|
||||||
def search_space(self) -> dict[str, list | rv_continuous_frozen | rv_discrete_frozen]:
|
def search_space(self) -> dict[str, list | rv_continuous_frozen | rv_discrete_frozen]:
|
||||||
"""Convert the HPConfig into a search space dictionary usable by sklearn's RandomizedSearchCV."""
|
"""Convert the HPConfig into a search space dictionary usable by sklearn's RandomizedSearchCV."""
|
||||||
search_space = {}
|
return get_search_space(self.hp_config)
|
||||||
for key, dist in self.hp_config.items():
|
|
||||||
if isinstance(dist, list):
|
|
||||||
search_space[key] = dist
|
|
||||||
continue
|
|
||||||
assert hasattr(scipy.stats, dist["distribution"]), (
|
|
||||||
f"Unknown distribution type for {key}: {dist['distribution']}"
|
|
||||||
)
|
|
||||||
distfn = getattr(scipy.stats, dist["distribution"])
|
|
||||||
search_space[key] = distfn(dist["low"], dist["high"])
|
|
||||||
return search_space
|
|
||||||
|
|
||||||
|
|
||||||
# Hardcode Search Settings for now
|
# Hardcode Search Settings for now
|
||||||
espa_hpconfig: HPConfig = {
|
espa_hpconfig: HPConfig = {
|
||||||
"eps_cl": {"distribution": "loguniform", "low": 1e-11, "high": 1e-6},
|
"eps_cl": {"distribution": "loguniform", "low": 1e-11, "high": 1e-6},
|
||||||
"eps_e": {"distribution": "loguniform", "low": 1e4, "high": 1e8},
|
"eps_e": {"distribution": "loguniform", "low": 1e4, "high": 1e8},
|
||||||
"initial_K": {"distribution": "randint", "low": 400, "high": 800},
|
"initial_K": {"distribution": "randint", "low": 50, "high": 800},
|
||||||
}
|
}
|
||||||
xgboost_hpconfig: HPConfig = {
|
xgboost_hpconfig: HPConfig = {
|
||||||
"learning_rate": {"distribution": "loguniform", "low": 1e-3, "high": 1e-1},
|
# Learning & Regularization
|
||||||
"n_estimators": {"distribution": "randint", "low": 50, "high": 2000},
|
"learning_rate": {"distribution": "loguniform", "low": 0.01, "high": 0.3},
|
||||||
|
"n_estimators": {"distribution": "randint", "low": 100, "high": 500},
|
||||||
|
# Tree Structure (critical for overfitting control)
|
||||||
|
"max_depth": {"distribution": "randint", "low": 3, "high": 10},
|
||||||
|
# "min_child_weight": {"distribution": "randint", "low": 1, "high": 10},
|
||||||
|
# Feature Sampling (important for high-dimensional data)
|
||||||
|
"colsample_bytree": {"distribution": "uniform", "low": 0.3, "high": 1.0},
|
||||||
|
# "colsample_bylevel": {"distribution": "uniform", "low": 0.3, "high": 1.0},
|
||||||
|
# Row Sampling
|
||||||
|
# "subsample": {"distribution": "uniform", "low": 0.5, "high": 1.0},
|
||||||
|
# Regularization
|
||||||
|
# "reg_alpha": {"distribution": "loguniform", "low": 1e-8, "high": 1.0}, # L1
|
||||||
|
"reg_lambda": {"distribution": "loguniform", "low": 1e-8, "high": 10.0}, # L2
|
||||||
}
|
}
|
||||||
rf_hpconfig: HPConfig = {
|
rf_hpconfig: HPConfig = {
|
||||||
"max_depth": {"distribution": "randint", "low": 5, "high": 50},
|
# Tree Structure
|
||||||
"n_estimators": {"distribution": "randint", "low": 50, "high": 1000},
|
"max_depth": {"distribution": "randint", "low": 10, "high": 50},
|
||||||
|
"n_estimators": {"distribution": "randint", "low": 50, "high": 300},
|
||||||
|
# Split Criteria (critical for >100 features)
|
||||||
|
"max_features": {"distribution": "uniform", "low": 0.1, "high": 0.5}, # sqrt(100) ≈ 10% of features
|
||||||
|
"min_samples_split": {"distribution": "randint", "low": 2, "high": 20},
|
||||||
|
# "min_samples_leaf": {"distribution": "randint", "low": 1, "high": 10},
|
||||||
|
# Regularization (cuML specific)
|
||||||
|
# "min_impurity_decrease": {"distribution": "loguniform", "low": 1e-7, "high": 1e-2},
|
||||||
|
# Bootstrap
|
||||||
|
# "max_samples": {"distribution": "uniform", "low": 0.5, "high": 1.0},
|
||||||
}
|
}
|
||||||
knn_hpconfig: HPConfig = {
|
knn_hpconfig: HPConfig = {
|
||||||
"n_neighbors": {"distribution": "randint", "low": 10, "high": 200},
|
"n_neighbors": {"distribution": "randint", "low": 10, "high": 200},
|
||||||
|
|
@ -134,7 +200,7 @@ def get_model_hpo_config(model: str, task: Task, **model_kwargs) -> ModelHPOConf
|
||||||
raise ValueError(f"Unsupported model/task combination: {model}/{task}")
|
raise ValueError(f"Unsupported model/task combination: {model}/{task}")
|
||||||
|
|
||||||
|
|
||||||
def extract_espa_feature_importance(model: ESPAClassifier, training_data: TrainingSet) -> xr.Dataset:
|
def extract_espa_state(model: ESPAClassifier, training_data: TrainingSet) -> xr.Dataset:
|
||||||
"""Extract the inner state of a trained ESPAClassifier as an xarray Dataset."""
|
"""Extract the inner state of a trained ESPAClassifier as an xarray Dataset."""
|
||||||
# Annotate the state with xarray metadata
|
# Annotate the state with xarray metadata
|
||||||
boxes = list(range(model.K_))
|
boxes = list(range(model.K_))
|
||||||
|
|
@ -157,7 +223,7 @@ def extract_espa_feature_importance(model: ESPAClassifier, training_data: Traini
|
||||||
dims=["feature"],
|
dims=["feature"],
|
||||||
coords={"feature": training_data.feature_names},
|
coords={"feature": training_data.feature_names},
|
||||||
name="feature_weights",
|
name="feature_weights",
|
||||||
attrs={"description": "Feature weights for each box."},
|
attrs={"description": "Weights for each feature."},
|
||||||
)
|
)
|
||||||
state = xr.Dataset(
|
state = xr.Dataset(
|
||||||
{
|
{
|
||||||
|
|
@ -172,130 +238,68 @@ def extract_espa_feature_importance(model: ESPAClassifier, training_data: Traini
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def extract_xgboost_feature_importance(model: XGBClassifier | XGBRegressor, training_data: TrainingSet) -> xr.Dataset:
|
def extract_espa_feature_importance(model: ESPAClassifier, training_data: TrainingSet) -> pd.DataFrame:
|
||||||
"""Extract feature importance from a trained XGBoost model as an xarray Dataset."""
|
"""Extract feature importance from a trained ESPAClassifier as a pandas DataFrame."""
|
||||||
# Extract XGBoost-specific information
|
weights = model.W_
|
||||||
# Get the underlying booster
|
if isinstance(weights, torch.Tensor):
|
||||||
|
weights = weights.cpu().numpy()
|
||||||
|
elif isinstance(weights, cp.ndarray):
|
||||||
|
weights = weights.get()
|
||||||
|
return pd.DataFrame(
|
||||||
|
{"importance": weights},
|
||||||
|
index=training_data.feature_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_xgboost_feature_importance(model: XGBClassifier | XGBRegressor, training_data: TrainingSet) -> pd.DataFrame:
|
||||||
|
"""Extract feature importance from a trained XGBoost model as a pandas DataFrame."""
|
||||||
booster = model.get_booster()
|
booster = model.get_booster()
|
||||||
|
fi = {}
|
||||||
# Feature importance with different importance types
|
for metric in ["weight", "gain", "cover", "total_gain", "total_cover"]:
|
||||||
# Note: get_score() returns dict with keys like 'f0', 'f1', etc. (feature indices)
|
scores = booster.get_score(importance_type=metric)
|
||||||
importance_weight = booster.get_score(importance_type="weight")
|
fi[metric] = [scores.get(f"f{i}", 0.0) for i in range(len(training_data.feature_names))]
|
||||||
importance_gain = booster.get_score(importance_type="gain")
|
return pd.DataFrame(fi, index=training_data.feature_names)
|
||||||
importance_cover = booster.get_score(importance_type="cover")
|
|
||||||
importance_total_gain = booster.get_score(importance_type="total_gain")
|
|
||||||
importance_total_cover = booster.get_score(importance_type="total_cover")
|
|
||||||
|
|
||||||
# Create aligned arrays for all features (including zero-importance)
|
|
||||||
def align_importance(importance_dict, features):
|
|
||||||
"""Align importance dict to feature list, filling missing with 0.
|
|
||||||
|
|
||||||
XGBoost returns feature indices (f0, f1, ...) as keys, so we need to map them.
|
|
||||||
"""
|
|
||||||
return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]
|
|
||||||
|
|
||||||
feature_importance_weight = xr.DataArray(
|
|
||||||
align_importance(importance_weight, training_data.feature_names),
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance_weight",
|
|
||||||
attrs={"description": "Number of times a feature is used to split the data across all trees."},
|
|
||||||
)
|
|
||||||
feature_importance_gain = xr.DataArray(
|
|
||||||
align_importance(importance_gain, training_data.feature_names),
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance_gain",
|
|
||||||
attrs={"description": "Average gain across all splits the feature is used in."},
|
|
||||||
)
|
|
||||||
feature_importance_cover = xr.DataArray(
|
|
||||||
align_importance(importance_cover, training_data.feature_names),
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance_cover",
|
|
||||||
attrs={"description": "Average coverage across all splits the feature is used in."},
|
|
||||||
)
|
|
||||||
feature_importance_total_gain = xr.DataArray(
|
|
||||||
align_importance(importance_total_gain, training_data.feature_names),
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance_total_gain",
|
|
||||||
attrs={"description": "Total gain across all splits the feature is used in."},
|
|
||||||
)
|
|
||||||
feature_importance_total_cover = xr.DataArray(
|
|
||||||
align_importance(importance_total_cover, training_data.feature_names),
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance_total_cover",
|
|
||||||
attrs={"description": "Total coverage across all splits the feature is used in."},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store tree information
|
|
||||||
n_trees = booster.num_boosted_rounds()
|
|
||||||
|
|
||||||
state = xr.Dataset(
|
|
||||||
{
|
|
||||||
"feature_importance_weight": feature_importance_weight,
|
|
||||||
"feature_importance_gain": feature_importance_gain,
|
|
||||||
"feature_importance_cover": feature_importance_cover,
|
|
||||||
"feature_importance_total_gain": feature_importance_total_gain,
|
|
||||||
"feature_importance_total_cover": feature_importance_total_cover,
|
|
||||||
},
|
|
||||||
attrs={
|
|
||||||
"description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
|
|
||||||
"n_trees": n_trees,
|
|
||||||
"objective": str(model.objective),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def extract_rf_feature_importance(
|
def extract_rf_feature_importance(
|
||||||
model: RandomForestClassifier | RandomForestRegressor, training_data: TrainingSet
|
model: RandomForestClassifier | RandomForestRegressor, training_data: TrainingSet
|
||||||
) -> xr.Dataset:
|
) -> pd.DataFrame:
|
||||||
"""Extract feature importance from a trained RandomForest model as an xarray Dataset."""
|
"""Extract feature importance from a trained RandomForest model as a pandas DataFrame."""
|
||||||
# Extract Random Forest-specific information
|
# Extract Random Forest-specific information
|
||||||
# Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_)
|
# Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_)
|
||||||
# like sklearn does, so we can only extract feature importances and model parameters
|
# like sklearn does, so we can only extract feature importances and model parameters
|
||||||
|
|
||||||
# Feature importances (Gini importance)
|
# Feature importances (Gini importance)
|
||||||
feature_importances = model.feature_importances_
|
feature_importances = {"gini": model.feature_importances_}
|
||||||
|
return pd.DataFrame(feature_importances, index=training_data.feature_names)
|
||||||
|
|
||||||
feature_importance = xr.DataArray(
|
|
||||||
feature_importances,
|
|
||||||
dims=["feature"],
|
|
||||||
coords={"feature": training_data.feature_names},
|
|
||||||
name="feature_importance",
|
|
||||||
attrs={"description": "Gini importance (impurity-based feature importance)."},
|
|
||||||
)
|
|
||||||
|
|
||||||
# cuML RF doesn't expose individual trees, so we store model parameters instead
|
def get_splitter(
|
||||||
n_estimators = model.n_estimators
|
splitter: Splitter, n_splits: int
|
||||||
max_depth = model.max_depth
|
) -> (
|
||||||
|
KFold
|
||||||
# OOB score if available
|
| StratifiedKFold
|
||||||
oob_score = None
|
| ShuffleSplit
|
||||||
if hasattr(model, "oob_score_") and model.oob_score:
|
| StratifiedShuffleSplit
|
||||||
oob_score = float(model.oob_score_)
|
| GroupKFold
|
||||||
|
| StratifiedGroupKFold
|
||||||
# cuML RandomForest doesn't provide per-tree statistics like sklearn
|
| GroupShuffleSplit
|
||||||
# Store what we have: feature importances and model configuration
|
):
|
||||||
attrs = {
|
"""Get a scikit-learn splitter object based on the specified splitter type."""
|
||||||
"description": "Inner state of the best RandomForestClassifier from RandomizedSearchCV (cuML).",
|
match splitter:
|
||||||
"n_estimators": int(n_estimators),
|
case "kfold":
|
||||||
"note": "cuML RandomForest does not expose individual tree statistics like sklearn",
|
return KFold(n_splits=n_splits, shuffle=True, random_state=42)
|
||||||
}
|
case "shuffle":
|
||||||
|
return ShuffleSplit(n_splits=n_splits, test_size=0.2, random_state=42)
|
||||||
# Only add optional attributes if they have values
|
case "stratified":
|
||||||
if max_depth is not None:
|
return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
|
||||||
attrs["max_depth"] = int(max_depth)
|
case "stratified_shuffle":
|
||||||
if oob_score is not None:
|
return StratifiedShuffleSplit(n_splits=n_splits, test_size=0.2, random_state=42)
|
||||||
attrs["oob_score"] = oob_score
|
case "group":
|
||||||
|
return GroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
|
||||||
state = xr.Dataset(
|
case "stratified_group":
|
||||||
{
|
return StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
|
||||||
"feature_importance": feature_importance,
|
case "shuffle_group":
|
||||||
},
|
return GroupShuffleSplit(n_splits=n_splits, test_size=0.2, random_state=42)
|
||||||
attrs=attrs,
|
case _:
|
||||||
)
|
raise ValueError(f"Unsupported splitter type: {splitter}")
|
||||||
return state
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
"""Training of classification models training."""
|
"""DEPRECATED!!! Training of classification models training."""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -13,15 +12,7 @@ import xarray as xr
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.metrics import (
|
from sklearn.metrics import (
|
||||||
accuracy_score,
|
|
||||||
confusion_matrix,
|
confusion_matrix,
|
||||||
f1_score,
|
|
||||||
jaccard_score,
|
|
||||||
mean_absolute_error,
|
|
||||||
mean_squared_error,
|
|
||||||
precision_score,
|
|
||||||
r2_score,
|
|
||||||
recall_score,
|
|
||||||
)
|
)
|
||||||
from sklearn.model_selection import KFold, RandomizedSearchCV
|
from sklearn.model_selection import KFold, RandomizedSearchCV
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
@ -30,11 +21,13 @@ from entropice.ml.dataset import DatasetEnsemble, SplittedArrays
|
||||||
from entropice.ml.inference import predict_proba
|
from entropice.ml.inference import predict_proba
|
||||||
from entropice.ml.models import (
|
from entropice.ml.models import (
|
||||||
extract_espa_feature_importance,
|
extract_espa_feature_importance,
|
||||||
|
extract_espa_state,
|
||||||
extract_rf_feature_importance,
|
extract_rf_feature_importance,
|
||||||
extract_xgboost_feature_importance,
|
extract_xgboost_feature_importance,
|
||||||
get_model_hpo_config,
|
get_model_hpo_config,
|
||||||
)
|
)
|
||||||
from entropice.utils.paths import get_cv_results_dir
|
from entropice.utils.metrics import get_metrics, metric_functions
|
||||||
|
from entropice.utils.paths import get_training_results_dir
|
||||||
from entropice.utils.types import Model, TargetDataset, Task
|
from entropice.utils.types import Model, TargetDataset, Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
|
|
@ -44,56 +37,9 @@ pretty.install()
|
||||||
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
|
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
|
||||||
|
|
||||||
|
|
||||||
def _get_metrics(task: Task) -> tuple[list[str], str]:
|
|
||||||
"""Get the list of metrics for a given task."""
|
|
||||||
if task == "binary":
|
|
||||||
return ["accuracy", "recall", "precision", "f1", "jaccard"], "f1"
|
|
||||||
elif task in ["count_regimes", "density_regimes"]:
|
|
||||||
return [
|
|
||||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
|
||||||
"f1_macro",
|
|
||||||
"f1_weighted",
|
|
||||||
"precision_macro",
|
|
||||||
"precision_weighted",
|
|
||||||
"recall_macro",
|
|
||||||
"jaccard_micro",
|
|
||||||
"jaccard_macro",
|
|
||||||
"jaccard_weighted",
|
|
||||||
], "f1_weighted"
|
|
||||||
else:
|
|
||||||
return [
|
|
||||||
"neg_mean_squared_error",
|
|
||||||
"neg_mean_absolute_error",
|
|
||||||
"r2",
|
|
||||||
], "r2"
|
|
||||||
|
|
||||||
|
|
||||||
# Compute other metrics - using predictions directly instead of re-predicting for each metric
|
|
||||||
# Use functools.partial for cleaner metric definitions with non-default parameters
|
|
||||||
_metric_functions = {
|
|
||||||
"accuracy": accuracy_score,
|
|
||||||
"recall": recall_score,
|
|
||||||
"precision": precision_score,
|
|
||||||
"f1": f1_score,
|
|
||||||
"jaccard": jaccard_score,
|
|
||||||
"recall_macro": partial(recall_score, average="macro"),
|
|
||||||
"recall_weighted": partial(recall_score, average="weighted"),
|
|
||||||
"precision_macro": partial(precision_score, average="macro"),
|
|
||||||
"precision_weighted": partial(precision_score, average="weighted"),
|
|
||||||
"f1_macro": partial(f1_score, average="macro"),
|
|
||||||
"f1_weighted": partial(f1_score, average="weighted"),
|
|
||||||
"jaccard_micro": partial(jaccard_score, average="micro"),
|
|
||||||
"jaccard_macro": partial(jaccard_score, average="macro"),
|
|
||||||
"jaccard_weighted": partial(jaccard_score, average="weighted"),
|
|
||||||
"neg_mean_squared_error": mean_squared_error,
|
|
||||||
"neg_mean_absolute_error": mean_absolute_error,
|
|
||||||
"r2": r2_score,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class CVSettings:
|
class RunSettings:
|
||||||
"""Cross-validation settings for model training."""
|
"""Cross-validation settings for model training."""
|
||||||
|
|
||||||
n_iter: int = 2000
|
n_iter: int = 2000
|
||||||
|
|
@ -103,7 +49,7 @@ class CVSettings:
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class TrainingSettings(DatasetEnsemble, CVSettings):
|
class TrainingSettings(DatasetEnsemble, RunSettings):
|
||||||
"""Helper Wrapper to store combined training and dataset ensemble settings."""
|
"""Helper Wrapper to store combined training and dataset ensemble settings."""
|
||||||
|
|
||||||
param_grid: dict
|
param_grid: dict
|
||||||
|
|
@ -115,14 +61,14 @@ class TrainingSettings(DatasetEnsemble, CVSettings):
|
||||||
@cli.default
|
@cli.default
|
||||||
def random_cv(
|
def random_cv(
|
||||||
dataset_ensemble: DatasetEnsemble,
|
dataset_ensemble: DatasetEnsemble,
|
||||||
settings: CVSettings = CVSettings(),
|
settings: RunSettings = RunSettings(),
|
||||||
experiment: str | None = None,
|
experiment: str | None = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
settings (CVSettings): The cross-validation settings.
|
settings (RunSettings): The cross-validation settings.
|
||||||
experiment (str | None): Optional experiment name for results directory.
|
experiment (str | None): Optional experiment name for results directory.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -136,7 +82,7 @@ def random_cv(
|
||||||
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
||||||
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
metrics, refit = _get_metrics(settings.task)
|
metrics, refit = get_metrics(settings.task)
|
||||||
search = RandomizedSearchCV(
|
search = RandomizedSearchCV(
|
||||||
model_hpo_config.model,
|
model_hpo_config.model,
|
||||||
model_hpo_config.search_space,
|
model_hpo_config.search_space,
|
||||||
|
|
@ -175,12 +121,14 @@ def random_cv(
|
||||||
)
|
)
|
||||||
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
|
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
|
||||||
|
|
||||||
results_dir = get_cv_results_dir(
|
results_dir = get_training_results_dir(
|
||||||
experiment=experiment,
|
experiment=experiment,
|
||||||
name="random_search",
|
name="random_search",
|
||||||
grid=dataset_ensemble.grid,
|
grid=dataset_ensemble.grid,
|
||||||
level=dataset_ensemble.level,
|
level=dataset_ensemble.level,
|
||||||
task=settings.task,
|
task=settings.task,
|
||||||
|
target=settings.target,
|
||||||
|
model_type=settings.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the search settings
|
# Store the search settings
|
||||||
|
|
@ -221,9 +169,9 @@ def random_cv(
|
||||||
|
|
||||||
# Compute and StoreMetrics
|
# Compute and StoreMetrics
|
||||||
y = training_data.y.as_numpy()
|
y = training_data.y.as_numpy()
|
||||||
test_metrics = {metric: _metric_functions[metric](y.test, y_pred.test) for metric in metrics}
|
test_metrics = {metric: metric_functions[metric](y.test, y_pred.test) for metric in metrics}
|
||||||
train_metrics = {metric: _metric_functions[metric](y.train, y_pred.train) for metric in metrics}
|
train_metrics = {metric: metric_functions[metric](y.train, y_pred.train) for metric in metrics}
|
||||||
combined_metrics = {metric: _metric_functions[metric](y.combined, y_pred.combined) for metric in metrics}
|
combined_metrics = {metric: metric_functions[metric](y.combined, y_pred.combined) for metric in metrics}
|
||||||
all_metrics = {
|
all_metrics = {
|
||||||
"test_metrics": test_metrics,
|
"test_metrics": test_metrics,
|
||||||
"train_metrics": train_metrics,
|
"train_metrics": train_metrics,
|
||||||
|
|
@ -255,31 +203,30 @@ def random_cv(
|
||||||
|
|
||||||
# Get the inner state of the best estimator
|
# Get the inner state of the best estimator
|
||||||
if settings.model == "espa":
|
if settings.model == "espa":
|
||||||
state = extract_espa_feature_importance(best_estimator, training_data)
|
state = extract_espa_state(best_estimator, training_data)
|
||||||
state_file = results_dir / "best_estimator_state.nc"
|
state_file = results_dir / "best_estimator_state.nc"
|
||||||
print(f"Storing best estimator state to {state_file}")
|
print(f"Storing best estimator state to {state_file}")
|
||||||
state.to_netcdf(state_file, engine="h5netcdf")
|
state.to_netcdf(state_file, engine="h5netcdf")
|
||||||
|
fi = extract_espa_feature_importance(best_estimator, training_data)
|
||||||
|
fi_file = results_dir / "best_estimator_feature_importance.parquet"
|
||||||
|
print(f"Storing best estimator feature importance to {fi_file}")
|
||||||
|
fi.to_parquet(fi_file)
|
||||||
|
|
||||||
elif settings.model == "xgboost":
|
elif settings.model == "xgboost":
|
||||||
state = extract_xgboost_feature_importance(best_estimator, training_data)
|
fi = extract_xgboost_feature_importance(best_estimator, training_data)
|
||||||
state_file = results_dir / "best_estimator_state.nc"
|
fi_file = results_dir / "best_estimator_feature_importance.parquet"
|
||||||
print(f"Storing best estimator state to {state_file}")
|
print(f"Storing best estimator feature importance to {fi_file}")
|
||||||
state.to_netcdf(state_file, engine="h5netcdf")
|
fi.to_parquet(fi_file)
|
||||||
|
|
||||||
elif settings.model == "rf":
|
elif settings.model == "rf":
|
||||||
state = extract_rf_feature_importance(best_estimator, training_data)
|
fi = extract_rf_feature_importance(best_estimator, training_data)
|
||||||
state_file = results_dir / "best_estimator_state.nc"
|
fi_file = results_dir / "best_estimator_feature_importance.parquet"
|
||||||
print(f"Storing best estimator state to {state_file}")
|
print(f"Storing best estimator feature importance to {fi_file}")
|
||||||
state.to_netcdf(state_file, engine="h5netcdf")
|
fi.to_parquet(fi_file)
|
||||||
|
|
||||||
# Predict probabilities for all cells
|
# Predict probabilities for all cells
|
||||||
print("Predicting probabilities for all cells...")
|
print("Predicting probabilities for all cells...")
|
||||||
preds = predict_proba(dataset_ensemble, model=best_estimator, device=device)
|
preds = predict_proba(dataset_ensemble, model=best_estimator, task=settings.task, device=device)
|
||||||
if training_data.targets["y"].dtype == "category":
|
|
||||||
preds["predicted"] = preds["predicted"].astype("category")
|
|
||||||
preds["predicted"] = preds["predicted"].cat.set_categories(
|
|
||||||
training_data.targets["y"].cat.categories, ordered=True
|
|
||||||
)
|
|
||||||
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
|
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
print(f"Storing predicted probabilities to {preds_file}")
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
63
src/entropice/utils/metrics.py
Normal file
63
src/entropice/utils/metrics.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Metrics for model evaluation and hyperparameter optimization."""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score,
|
||||||
|
f1_score,
|
||||||
|
jaccard_score,
|
||||||
|
mean_absolute_error,
|
||||||
|
mean_squared_error,
|
||||||
|
precision_score,
|
||||||
|
r2_score,
|
||||||
|
recall_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
from entropice.utils.types import Task
|
||||||
|
|
||||||
|
|
||||||
|
def get_metrics(task: Task) -> tuple[list[str], str]:
|
||||||
|
"""Get the list of metrics for a given task."""
|
||||||
|
if task == "binary":
|
||||||
|
return ["accuracy", "recall", "precision", "f1", "jaccard"], "f1"
|
||||||
|
elif task in ["count_regimes", "density_regimes"]:
|
||||||
|
return [
|
||||||
|
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||||
|
"f1_macro",
|
||||||
|
"f1_weighted",
|
||||||
|
"precision_macro",
|
||||||
|
"precision_weighted",
|
||||||
|
"recall_macro",
|
||||||
|
"jaccard_micro",
|
||||||
|
"jaccard_macro",
|
||||||
|
"jaccard_weighted",
|
||||||
|
], "f1_weighted"
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
"neg_mean_squared_error",
|
||||||
|
"neg_mean_absolute_error",
|
||||||
|
"r2",
|
||||||
|
], "r2"
|
||||||
|
|
||||||
|
|
||||||
|
# Compute other metrics - using predictions directly instead of re-predicting for each metric
|
||||||
|
# Use functools.partial for cleaner metric definitions with non-default parameters
|
||||||
|
metric_functions = {
|
||||||
|
"accuracy": accuracy_score,
|
||||||
|
"recall": recall_score,
|
||||||
|
"precision": precision_score,
|
||||||
|
"f1": f1_score,
|
||||||
|
"jaccard": jaccard_score,
|
||||||
|
"recall_macro": partial(recall_score, average="macro"),
|
||||||
|
"recall_weighted": partial(recall_score, average="weighted"),
|
||||||
|
"precision_macro": partial(precision_score, average="macro"),
|
||||||
|
"precision_weighted": partial(precision_score, average="weighted"),
|
||||||
|
"f1_macro": partial(f1_score, average="macro"),
|
||||||
|
"f1_weighted": partial(f1_score, average="weighted"),
|
||||||
|
"jaccard_micro": partial(jaccard_score, average="micro"),
|
||||||
|
"jaccard_macro": partial(jaccard_score, average="macro"),
|
||||||
|
"jaccard_weighted": partial(jaccard_score, average="weighted"),
|
||||||
|
"neg_mean_squared_error": mean_squared_error,
|
||||||
|
"neg_mean_absolute_error": mean_absolute_error,
|
||||||
|
"r2": r2_score,
|
||||||
|
}
|
||||||
|
|
@ -6,7 +6,7 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from entropice.utils.types import Grid, Task, TemporalMode
|
from entropice.utils.types import Grid, Model, TargetDataset, Task, TemporalMode
|
||||||
|
|
||||||
DATA_DIR = (
|
DATA_DIR = (
|
||||||
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||||
|
|
@ -147,12 +147,14 @@ def get_dataset_stats_cache() -> Path:
|
||||||
return cache_dir / "dataset_stats.pckl"
|
return cache_dir / "dataset_stats.pckl"
|
||||||
|
|
||||||
|
|
||||||
def get_cv_results_dir(
|
def get_training_results_dir(
|
||||||
experiment: str | None,
|
experiment: str | None,
|
||||||
name: str,
|
name: str,
|
||||||
grid: Grid,
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
task: Task,
|
task: Task,
|
||||||
|
target: TargetDataset,
|
||||||
|
model_type: Model | None = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
|
@ -161,24 +163,9 @@ def get_cv_results_dir(
|
||||||
experiment_dir.mkdir(parents=True, exist_ok=True)
|
experiment_dir.mkdir(parents=True, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
experiment_dir = RESULTS_DIR
|
experiment_dir = RESULTS_DIR
|
||||||
results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}"
|
parts = [gridname, name, now, task, target]
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
if model_type is not None:
|
||||||
return results_dir
|
parts.append(model_type)
|
||||||
|
results_dir = experiment_dir / "_".join(parts)
|
||||||
|
|
||||||
def get_autogluon_results_dir(
|
|
||||||
experiment: str | None,
|
|
||||||
grid: Grid,
|
|
||||||
level: int,
|
|
||||||
task: Task,
|
|
||||||
) -> Path:
|
|
||||||
gridname = _get_gridname(grid, level)
|
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
||||||
if experiment is not None:
|
|
||||||
experiment_dir = RESULTS_DIR / experiment
|
|
||||||
experiment_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
else:
|
|
||||||
experiment_dir = RESULTS_DIR
|
|
||||||
results_dir = experiment_dir / f"{gridname}_autogluon_{now}_{task}"
|
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return results_dir
|
return results_dir
|
||||||
|
|
|
||||||
238
src/entropice/utils/training.py
Normal file
238
src/entropice/utils/training.py
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
import pickle
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import cupy as cp
|
||||||
|
import geopandas as gpd
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import toml
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from shap import Explanation
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||||
|
from entropice.ml.models import (
|
||||||
|
HPConfig,
|
||||||
|
extract_espa_state,
|
||||||
|
get_search_space,
|
||||||
|
)
|
||||||
|
from entropice.utils.types import HPSearch, Model, Scaler, Splitter, TargetDataset, Task
|
||||||
|
|
||||||
|
type ndarray = np.ndarray | torch.Tensor | cp.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
def move_data_to_device(data: ndarray, device: Literal["torch", "cuda", "cpu"]) -> ndarray:
|
||||||
|
"""Move the given data to the specified device (CPU, CUDA, or PyTorch tensor)."""
|
||||||
|
match device:
|
||||||
|
case "torch":
|
||||||
|
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
return torch.asarray(data, device=torch_device)
|
||||||
|
case "cuda":
|
||||||
|
if isinstance(data, cp.ndarray):
|
||||||
|
return data
|
||||||
|
with cp.cuda.Device(0):
|
||||||
|
return cp.asarray(data)
|
||||||
|
case "cpu":
|
||||||
|
if isinstance(data, np.ndarray):
|
||||||
|
return data
|
||||||
|
elif isinstance(data, torch.Tensor):
|
||||||
|
return data.cpu().numpy()
|
||||||
|
elif isinstance(data, cp.ndarray):
|
||||||
|
return data.get()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown device: {device}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HPOCV:
|
||||||
|
method: HPSearch
|
||||||
|
splitter: Splitter
|
||||||
|
scaler: Scaler
|
||||||
|
normalize: bool
|
||||||
|
n_iter: int
|
||||||
|
hpconfig: HPConfig
|
||||||
|
|
||||||
|
@property
|
||||||
|
def search_space(self):
|
||||||
|
return get_search_space(self.hpconfig)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoML:
|
||||||
|
time_budget: int
|
||||||
|
preset: str
|
||||||
|
hpo: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Training:
|
||||||
|
"""Configuration and results of a training run.
|
||||||
|
|
||||||
|
A training run involves the complete approach of training a machine learning model, currently:
|
||||||
|
|
||||||
|
1. A hyperparameter optimization using cross-validation
|
||||||
|
2. Use of AutoML techniques
|
||||||
|
|
||||||
|
Thus, a training run is always defined as:
|
||||||
|
|
||||||
|
f(dataset, method) -> (model, metrics)
|
||||||
|
|
||||||
|
Metrics refer to a simple dataframe in long format, with the columns: "metric", "split", "value".
|
||||||
|
Split is either "train", "test" or "complete".
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: Path
|
||||||
|
dataset: DatasetEnsemble
|
||||||
|
method: HPOCV | AutoML
|
||||||
|
task: Task
|
||||||
|
target: TargetDataset
|
||||||
|
training_set: TrainingSet # TODO: Store Training Set to improve loading time (?)
|
||||||
|
model: Any
|
||||||
|
model_type: Model
|
||||||
|
metrics: pd.DataFrame
|
||||||
|
feature_importance: pd.DataFrame
|
||||||
|
shap_explanation: Explanation
|
||||||
|
predictions: gpd.GeoDataFrame
|
||||||
|
confusion_matrix: xr.Dataset | None # only for classification tasks
|
||||||
|
cv_results: pd.DataFrame | None # only for HPOCV
|
||||||
|
leaderboard: pd.DataFrame | None # only for AutoGluon
|
||||||
|
|
||||||
|
def __repr__(self) -> str: # noqa: D105
|
||||||
|
return (
|
||||||
|
f"Training("
|
||||||
|
f"path={self.path}, "
|
||||||
|
f"dataset={self.dataset.grid}-{self.dataset.level}{self.dataset.members}, "
|
||||||
|
f"method={type(self.method).__name__}, "
|
||||||
|
f"task={self.task}, "
|
||||||
|
f"target={self.target})"
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def metric_names(self) -> list[str]:
|
||||||
|
"""Get the list of metric names from the metrics DataFrame."""
|
||||||
|
return self.metrics["metric"].unique().tolist()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_state(self) -> xr.Dataset | pd.DataFrame | None:
|
||||||
|
"""Get the inner state of the trained model, if available."""
|
||||||
|
if self.model_type == "espa":
|
||||||
|
return extract_espa_state(self.model, self.training_set)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""Save the training results to the specified path."""
|
||||||
|
self.path.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_file = self.path / "training_config.toml"
|
||||||
|
model_file = self.path / "model.pkl"
|
||||||
|
metrics_file = self.path / "metrics.parquet"
|
||||||
|
feature_importance_file = self.path / "feature_importance.parquet"
|
||||||
|
explanations_file = self.path / "shap_explanation.pkl"
|
||||||
|
predictions_file = self.path / "predictions.parquet"
|
||||||
|
# Save config
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
toml.dump(
|
||||||
|
{
|
||||||
|
"dataset": asdict(self.dataset),
|
||||||
|
"method": asdict(self.method),
|
||||||
|
"method_type": type(self.method).__name__,
|
||||||
|
"task": self.task,
|
||||||
|
"target": self.target,
|
||||||
|
"model_type": self.model_type,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save model, metrics, explanations and predictions
|
||||||
|
model_file.write_bytes(pickle.dumps(self.model))
|
||||||
|
self.metrics.to_parquet(metrics_file)
|
||||||
|
self.feature_importance.to_parquet(feature_importance_file)
|
||||||
|
explanations_file.write_bytes(pickle.dumps(self.shap_explanation))
|
||||||
|
self.predictions.to_parquet(predictions_file)
|
||||||
|
|
||||||
|
# Save the confusion matrix if it exists
|
||||||
|
if self.confusion_matrix is not None:
|
||||||
|
cm_file = self.path / "confusion_matrix.nc"
|
||||||
|
self.confusion_matrix.to_netcdf(cm_file, engine="h5netcdf")
|
||||||
|
|
||||||
|
if self.cv_results is not None:
|
||||||
|
results_file = self.path / "search_results.parquet"
|
||||||
|
self.cv_results.to_parquet(results_file)
|
||||||
|
|
||||||
|
# Save the leaderboard if it exists
|
||||||
|
if self.leaderboard is not None:
|
||||||
|
leaderboard_file = self.path / "leaderboard.parquet"
|
||||||
|
self.leaderboard.to_parquet(leaderboard_file)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: Path, device: Literal["cpu", "cuda"] = "cpu") -> "Training":
|
||||||
|
"""Load a training run from the specified path."""
|
||||||
|
config_file = path / "training_config.toml"
|
||||||
|
model_file = path / "model.pkl"
|
||||||
|
metrics_file = path / "metrics.parquet"
|
||||||
|
feature_importance_file = path / "feature_importance.parquet"
|
||||||
|
predictions_file = path / "predictions.parquet"
|
||||||
|
cm_file = path / "confusion_matrix.nc"
|
||||||
|
cv_results_file = path / "search_results.parquet"
|
||||||
|
leaderboard_file = path / "leaderboard.parquet"
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = toml.load(f)
|
||||||
|
|
||||||
|
task = config["task"]
|
||||||
|
target = config["target"]
|
||||||
|
model_type = config["model_type"]
|
||||||
|
|
||||||
|
dataset = DatasetEnsemble(**config["dataset"])
|
||||||
|
training_set = dataset.create_training_set(task, target, device)
|
||||||
|
|
||||||
|
method_type = config["method_type"]
|
||||||
|
if method_type == "HPOCV":
|
||||||
|
method = HPOCV(**config["method"])
|
||||||
|
elif method_type == "AutoML":
|
||||||
|
method = AutoML(**config["method"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method type: {method_type}")
|
||||||
|
|
||||||
|
# Load model, metrics, explanations and predictions
|
||||||
|
model = pickle.loads(model_file.read_bytes())
|
||||||
|
metrics = pd.read_parquet(metrics_file)
|
||||||
|
feature_importance = pd.read_parquet(feature_importance_file)
|
||||||
|
shap_explanation = pickle.loads((path / "shap_explanation.pkl").read_bytes())
|
||||||
|
predictions = gpd.read_parquet(predictions_file)
|
||||||
|
|
||||||
|
# Load confusion matrix if it exists
|
||||||
|
confusion_matrix = None
|
||||||
|
if cm_file.exists():
|
||||||
|
confusion_matrix = xr.load_dataset(cm_file, engine="h5netcdf")
|
||||||
|
|
||||||
|
cv_results = None
|
||||||
|
if cv_results_file.exists():
|
||||||
|
cv_results = pd.read_parquet(cv_results_file)
|
||||||
|
|
||||||
|
# Load leaderboard if it exists
|
||||||
|
leaderboard = None
|
||||||
|
if leaderboard_file.exists():
|
||||||
|
leaderboard = pd.read_parquet(leaderboard_file)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
path=path,
|
||||||
|
dataset=dataset,
|
||||||
|
method=method,
|
||||||
|
task=task,
|
||||||
|
target=target,
|
||||||
|
training_set=training_set,
|
||||||
|
model=model,
|
||||||
|
model_type=model_type,
|
||||||
|
metrics=metrics,
|
||||||
|
feature_importance=feature_importance,
|
||||||
|
shap_explanation=shap_explanation,
|
||||||
|
predictions=predictions,
|
||||||
|
confusion_matrix=confusion_matrix,
|
||||||
|
cv_results=cv_results,
|
||||||
|
leaderboard=leaderboard,
|
||||||
|
)
|
||||||
|
|
@ -19,10 +19,15 @@ type TargetDataset = Literal["darts_v1", "darts_mllabels"]
|
||||||
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||||
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
# TODO: Consider implementing a "timeseries" temporal mode
|
# TODO: Consider implementing a "timeseries" and "event" temporal mode
|
||||||
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
||||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
type Model = Literal["espa", "xgboost", "rf", "knn", "autogluon"]
|
||||||
type Stage = Literal["train", "inference", "visualization"]
|
type Stage = Literal["train", "inference", "visualization"]
|
||||||
|
type HPSearch = Literal["random"] # TODO Grid and Bayesian search
|
||||||
|
type Splitter = Literal[
|
||||||
|
"kfold", "stratified", "shuffle", "stratified_shuffle", "group", "stratified_group", "shuffle_group"
|
||||||
|
]
|
||||||
|
type Scaler = Literal["standard", "minmax", "maxabs", "quantile", "none"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ import shutil
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.ml.training import CVSettings, random_cv
|
from entropice.ml.randomsearch import RunSettings, random_cv
|
||||||
from entropice.utils.types import Model, Task
|
from entropice.utils.types import Model, Task
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,7 +90,7 @@ class TestRandomCV:
|
||||||
- All output files are created
|
- All output files are created
|
||||||
"""
|
"""
|
||||||
# Use darts_v1 as the primary target for all tests
|
# Use darts_v1 as the primary target for all tests
|
||||||
settings = CVSettings(
|
settings = RunSettings(
|
||||||
n_iter=3,
|
n_iter=3,
|
||||||
task=task,
|
task=task,
|
||||||
target="darts_v1",
|
target="darts_v1",
|
||||||
|
|
@ -141,7 +141,7 @@ class TestRandomCV:
|
||||||
- xgboost: Uses CUDA without array API dispatch
|
- xgboost: Uses CUDA without array API dispatch
|
||||||
- rf/knn: GPU-accelerated via cuML
|
- rf/knn: GPU-accelerated via cuML
|
||||||
"""
|
"""
|
||||||
settings = CVSettings(
|
settings = RunSettings(
|
||||||
n_iter=3,
|
n_iter=3,
|
||||||
task="binary", # Simple binary task for device testing
|
task="binary", # Simple binary task for device testing
|
||||||
target="darts_v1",
|
target="darts_v1",
|
||||||
|
|
@ -168,7 +168,7 @@ class TestRandomCV:
|
||||||
|
|
||||||
def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
|
def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
|
||||||
"""Test random_cv with multi-label target dataset."""
|
"""Test random_cv with multi-label target dataset."""
|
||||||
settings = CVSettings(
|
settings = RunSettings(
|
||||||
n_iter=3,
|
n_iter=3,
|
||||||
task="binary",
|
task="binary",
|
||||||
target="darts_mllabels",
|
target="darts_mllabels",
|
||||||
|
|
@ -199,7 +199,7 @@ class TestRandomCV:
|
||||||
add_lonlat=True,
|
add_lonlat=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
settings = CVSettings(
|
settings = RunSettings(
|
||||||
n_iter=3,
|
n_iter=3,
|
||||||
task="binary",
|
task="binary",
|
||||||
target="darts_v1",
|
target="darts_v1",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue