Switch to predict_classes

This commit is contained in:
Tobias Hölzer 2025-11-09 01:39:36 +01:00
parent d4a747d800
commit b2cfddfead
3 changed files with 373 additions and 192 deletions

View file

@@ -18,13 +18,14 @@ pretty.install()
set_config(array_api_dispatch=True)
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
"""Get predicted probabilities for each cell.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
clf (ESPAClassifier): The trained classifier to use for predictions.
classes (list): List of class names.
Returns:
list: A list of predicted probabilities for each cell.
@@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32")
X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = gpd.GeoDataFrame(
{
"cell_id": cell_ids,
"predicted_proba": batch_preds,
"predicted_class": [classes[i] for i in batch_preds],
"geometry": cell_geoms,
},
crs="epsg:3413",

View file

@@ -29,13 +29,14 @@ pretty.install()
set_config(array_api_dispatch=True)
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
"""Perform random cross-validation on the training dataset.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False.
"""
data = get_train_dataset_file(grid=grid, level=level)
@@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"initial_K": randint(20, 400),
}
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42)
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
search = RandomizedSearchCV(
@@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
n_jobs=16,
cv=cv,
random_state=42,
verbose=10,
verbose=3,
scoring=metrics,
refit="f1",
)
@@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
},
"cv_splits": cv.get_n_splits(),
"metrics": metrics,
"classes": ["No RTS", "RTS"],
}
settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}")
@@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
# Predict probabilities for all cells
print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file)

View file

@@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]:
return sorted(result_files, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns.
Args:
file_path: Path to the results parquet file.
settings: Dictionary of search settings.
k_bin_width: Width of bins for initial_K parameter.
Returns:
@@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF
results = pd.read_parquet(file_path)
# Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min()
k_max = results["initial_K"].max()
k_min = settings["param_grid"]["initial_K"]["low"]
k_max = settings["param_grid"]["initial_K"]["high"]
# Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min())
eps_cl_max = np.log10(results["eps_cl"].max())
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"])
eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"])
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1))
eps_e_min = np.log10(results["eps_e"].min())
eps_e_max = np.log10(results["eps_e"].max())
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"])
eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"])
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1))
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
@@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
return common_feature_array
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
grid = settings["grid"]
level = settings["level"]
probs = xr.DataArray(
data=preds["predicted_proba"].to_numpy(),
dims=["cell_ids"],
coords={"cell_ids": preds["cell_id"].to_numpy()},
)
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
}
if grid == "healpix":
gridinfo["indexing_scheme"] = "nested"
probs.cell_ids.attrs = gridinfo
probs = xdggs.decode(probs)
return probs
def _plot_prediction_map(preds: gpd.GeoDataFrame):
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
# return probs.dggs.explore()
return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
"""Create a bar chart showing the count of each predicted class.
Args:
preds: GeoDataFrame with predicted classes.
Returns:
Altair chart showing class distribution.
"""
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
counts = df["predicted_class"].value_counts().reset_index()
counts.columns = ["class", "count"]
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
chart = (
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("class:N", title="Predicted Class"),
y=alt.Y("count:Q", title="Number of Cells"),
color=alt.Color("class:N", title="Class", legend=None),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Count"),
alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
],
)
.properties(
width=400,
height=300,
title="Predicted Class Distribution",
)
)
return chart
def _plot_k_binned(
@@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray):
return chart
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
"""Create a scatter plot comparing two metrics with parameter-based coloring.
Args:
results: DataFrame containing the results with metric columns.
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
Returns:
Altair chart showing the metric comparison.
"""
# Determine color scale based on parameter
if color_param in ["eps_cl", "eps_e"]:
color_scale = alt.Scale(type="log", scheme="viridis")
else:
color_scale = alt.Scale(scheme="viridis")
chart = (
alt.Chart(results)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)),
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)),
color=alt.Color(
f"{color_param}:Q",
scale=color_scale,
title=color_param,
),
tooltip=[
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()),
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()),
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
],
)
.properties(
width=400,
height=400,
)
.interactive()
)
return chart
def main():
"""Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@@ -769,7 +836,8 @@ def main():
# Load and prepare data with default bin width (will be reloaded with custom width later)
with st.spinner("Loading results..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40)
settings = toml.load(results_dir / "search_settings.toml")["settings"]
results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40)
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
@@ -777,29 +845,55 @@ def main():
era5_feature_array = extract_era5_features(model_state)
common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
settings = toml.load(results_dir / "search_settings.toml")["settings"]
probs = _extract_probs_as_xdggs(predictions, settings)
st.sidebar.success(f"Loaded {len(results)} results")
# Metric selection
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.sidebar.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Display some basic statistics first (lightweight)
st.header("Dataset Overview")
col1, col2, col3 = st.columns(3)
st.header("Parameter-Search Overview")
# Show total runs and best model info
col1, col2 = st.columns(2)
with col1:
st.metric("Total Runs", len(results))
with col2:
best_score = results[f"mean_test_{selected_metric}"].max()
st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}")
# Best model based on F1 score
best_f1_idx = results["mean_test_f1"].idxmax()
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
# Show best parameters for the best model
with st.expander("Best Model Parameters (by F1)", expanded=True):
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
with col2:
st.metric("eps_cl", f"{best_params['eps_cl']:.2e}")
with col3:
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
best_k = results.loc[best_idx, "initial_K"]
st.metric("Best K", f"{best_k:.0f}")
st.metric("eps_e", f"{best_params['eps_e']:.2e}")
with col4:
st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}")
with col5:
st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}")
# Show all metrics
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
cols = st.columns(len(available_metrics))
for idx, metric in enumerate(available_metrics):
with cols[idx]:
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
st.metric(f"{metric.capitalize()}", f"{best_score:.4f}")
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
with tab1:
# Metric selection - only used in this tab
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Show best parameters
with st.expander("Best Parameters"):
@@ -807,10 +901,6 @@ def main():
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
st.dataframe(best_params.to_frame().T, width="content")
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
with tab1:
# Main plots
st.header(f"Visualization for {selected_metric.capitalize()}")
@@ -847,7 +937,9 @@ def main():
# Reload data if bin width changed from default
if k_bin_width != 40:
with st.spinner("Re-binning data..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width)
results = load_and_prepare_results(
results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width
)
# K-binned plots
col1, col2 = st.columns(2)
@@ -885,6 +977,30 @@ def main():
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
st.altair_chart(chart4, use_container_width=True)
# Metric comparison plots
st.header("Metric Comparisons")
# Color parameter selection
color_param = st.selectbox(
"Select Color Parameter",
options=["initial_K", "eps_cl", "eps_e"],
help="Choose which parameter to use for coloring the scatter plots",
)
col1, col2 = st.columns(2)
with col1:
st.subheader("Recall vs Precision")
with st.spinner("Generating Recall vs Precision plot..."):
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
st.altair_chart(recall_precision_chart, use_container_width=True)
with col2:
st.subheader("Accuracy vs Jaccard")
with st.spinner("Generating Accuracy vs Jaccard plot..."):
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
# Optional: Raw data table
with st.expander("View Raw Results Data"):
st.dataframe(results, width="stretch")
@@ -993,7 +1109,8 @@ def main():
# Embedding features analysis (if present)
if embedding_feature_array is not None:
st.subheader(":artificial_satellite: Embedding Feature Analysis")
with st.container(border=True):
st.header(":artificial_satellite: Embedding Feature Analysis")
st.markdown(
"""
Analysis of embedding features showing which aggregations, bands, and years
@@ -1044,7 +1161,8 @@ def main():
# ERA5 features analysis (if present)
if era5_feature_array is not None:
st.subheader(":partly_sunny: ERA5 Feature Analysis")
with st.container(border=True):
st.header(":partly_sunny: ERA5 Feature Analysis")
st.markdown(
"""
Analysis of ERA5 climate features showing which variables and time periods
@@ -1093,7 +1211,8 @@ def main():
# Common features analysis (if present)
if common_feature_array is not None:
st.subheader(":world_map: Common Feature Analysis")
with st.container(border=True):
st.header(":world_map: Common Feature Analysis")
st.markdown(
"""
Analysis of common features including cell area, water area, land area, land ratio,
@@ -1133,7 +1252,8 @@ def main():
st.markdown(
"""
**Interpretation:**
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
size-related patterns
- **land_ratio**: Proportion of land vs water in each cell
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- Positive weights indicate features that increase the probability of the positive class
@@ -1144,13 +1264,71 @@ def main():
st.info("No common features found in this model.")
with tab3:
# Map visualization
st.header("Predictions Map")
st.markdown("Map showing predicted classes from the best estimator")
with st.spinner("Generating map..."):
# Inference analysis
st.header("Inference Analysis")
st.markdown("Comprehensive analysis of model predictions on the evaluation dataset")
# Summary statistics
st.subheader("Prediction Statistics")
col1, col2 = st.columns(2)
with col1:
total_cells = len(predictions)
st.metric("Total Cells", f"{total_cells:,}")
with col2:
n_classes = predictions["predicted_class"].nunique()
st.metric("Number of Classes", n_classes)
# Class distribution visualization
st.subheader("Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_prediction_class_distribution(predictions)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between predicted classes
- Class imbalance may indicate regional patterns or model bias
- Each bar represents the count of cells predicted for that class
"""
)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
with st.spinner("Generating interactive map..."):
chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Prediction Statistics"):
st.write("**Class Distribution:**")
class_counts = predictions["predicted_class"].value_counts().sort_index()
# Create columns for better layout
n_cols = min(5, len(class_counts))
cols = st.columns(n_cols)
for idx, (class_label, count) in enumerate(class_counts.items()):
percentage = count / len(predictions) * 100
with cols[idx % n_cols]:
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
# Show detailed table
st.write("**Detailed Class Breakdown:**")
class_df = pd.DataFrame(
{
"Class": class_counts.index,
"Count": class_counts.to_numpy(),
"Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2),
}
)
st.dataframe(class_df, width="stretch", hide_index=True)
if __name__ == "__main__":
main()