Switch to predict_classes

This commit is contained in:
Tobias Hölzer 2025-11-09 01:39:36 +01:00
parent d4a747d800
commit b2cfddfead
3 changed files with 373 additions and 192 deletions

View file

@ -18,13 +18,14 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier): def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
"""Get predicted probabilities for each cell. """Get predicted probabilities for each cell.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
clf (ESPAClassifier): The trained classifier to use for predictions. clf (ESPAClassifier): The trained classifier to use for predictions.
classes (list): List of class names.
Returns: Returns:
list: A list of predicted probabilities for each cell. list: A list of predicted probabilities for each cell.
@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32") X_batch = X_batch.to_numpy(dtype="float32")
X_batch = torch.asarray(X_batch, device=0) X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy() batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = gpd.GeoDataFrame( batch_preds = gpd.GeoDataFrame(
{ {
"cell_id": cell_ids, "cell_id": cell_ids,
"predicted_proba": batch_preds, "predicted_class": [classes[i] for i in batch_preds],
"geometry": cell_geoms, "geometry": cell_geoms,
}, },
crs="epsg:3413", crs="epsg:3413",

View file

@ -29,13 +29,14 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False.
""" """
data = get_train_dataset_file(grid=grid, level=level) data = get_train_dataset_file(grid=grid, level=level)
@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"initial_K": randint(20, 400), "initial_K": randint(20, 400),
} }
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42) clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
search = RandomizedSearchCV( search = RandomizedSearchCV(
@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
n_jobs=16, n_jobs=16,
cv=cv, cv=cv,
random_state=42, random_state=42,
verbose=10, verbose=3,
scoring=metrics, scoring=metrics,
refit="f1", refit="f1",
) )
@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
}, },
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": ["No RTS", "RTS"],
} }
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator) preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
preds_file = results_dir / "predicted_probabilities.parquet" preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}") print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file) preds.to_parquet(preds_file)

View file

@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]:
return sorted(result_files, reverse=True) # Most recent first return sorted(result_files, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame: def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns. """Load results file and prepare binned columns.
Args: Args:
file_path: Path to the results parquet file. file_path: Path to the results parquet file.
settings: Dictionary of search settings.
k_bin_width: Width of bins for initial_K parameter. k_bin_width: Width of bins for initial_K parameter.
Returns: Returns:
@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF
results = pd.read_parquet(file_path) results = pd.read_parquet(file_path)
# Automatically determine bin width for initial_K based on data range # Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min() k_min = settings["param_grid"]["initial_K"]["low"]
k_max = results["initial_K"].max() k_max = settings["param_grid"]["initial_K"]["high"]
# Use configurable bin width, adapted to actual data range # Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range # Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range # Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min()) eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"])
eps_cl_max = np.log10(results["eps_cl"].max()) eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"])
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10) eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1))
eps_e_min = np.log10(results["eps_e"].min()) eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"])
eps_e_max = np.log10(results["eps_e"].max()) eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"])
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10) eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1))
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
return common_feature_array return common_feature_array
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
grid = settings["grid"]
level = settings["level"]
probs = xr.DataArray(
data=preds["predicted_proba"].to_numpy(),
dims=["cell_ids"],
coords={"cell_ids": preds["cell_id"].to_numpy()},
)
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
}
if grid == "healpix":
gridinfo["indexing_scheme"] = "nested"
probs.cell_ids.attrs = gridinfo
probs = xdggs.decode(probs)
return probs
def _plot_prediction_map(preds: gpd.GeoDataFrame): def _plot_prediction_map(preds: gpd.GeoDataFrame):
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron") return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
# return probs.dggs.explore()
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
"""Create a bar chart showing the count of each predicted class.
Args:
preds: GeoDataFrame with predicted classes.
Returns:
Altair chart showing class distribution.
"""
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
counts = df["predicted_class"].value_counts().reset_index()
counts.columns = ["class", "count"]
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
chart = (
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("class:N", title="Predicted Class"),
y=alt.Y("count:Q", title="Number of Cells"),
color=alt.Color("class:N", title="Class", legend=None),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Count"),
alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
],
)
.properties(
width=400,
height=300,
title="Predicted Class Distribution",
)
)
return chart
def _plot_k_binned( def _plot_k_binned(
@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray):
return chart return chart
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
"""Create a scatter plot comparing two metrics with parameter-based coloring.
Args:
results: DataFrame containing the results with metric columns.
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
Returns:
Altair chart showing the metric comparison.
"""
# Determine color scale based on parameter
if color_param in ["eps_cl", "eps_e"]:
color_scale = alt.Scale(type="log", scheme="viridis")
else:
color_scale = alt.Scale(scheme="viridis")
chart = (
alt.Chart(results)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)),
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)),
color=alt.Color(
f"{color_param}:Q",
scale=color_scale,
title=color_param,
),
tooltip=[
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()),
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()),
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
],
)
.properties(
width=400,
height=400,
)
.interactive()
)
return chart
def main(): def main():
"""Run Streamlit dashboard application.""" """Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -769,7 +836,8 @@ def main():
# Load and prepare data with default bin width (will be reloaded with custom width later) # Load and prepare data with default bin width (will be reloaded with custom width later)
with st.spinner("Loading results..."): with st.spinner("Loading results..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40) settings = toml.load(results_dir / "search_settings.toml")["settings"]
results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40)
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc") model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
n_features = model_state.sizes["feature"] n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features model_state["feature_weights"] *= n_features
@ -777,29 +845,55 @@ def main():
era5_feature_array = extract_era5_features(model_state) era5_feature_array = extract_era5_features(model_state)
common_feature_array = extract_common_features(model_state) common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
settings = toml.load(results_dir / "search_settings.toml")["settings"]
probs = _extract_probs_as_xdggs(predictions, settings)
st.sidebar.success(f"Loaded {len(results)} results") st.sidebar.success(f"Loaded {len(results)} results")
# Metric selection
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.sidebar.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Display some basic statistics first (lightweight) # Display some basic statistics first (lightweight)
st.header("Dataset Overview") st.header("Parameter-Search Overview")
col1, col2, col3 = st.columns(3)
# Show total runs and best model info
col1, col2 = st.columns(2)
with col1: with col1:
st.metric("Total Runs", len(results)) st.metric("Total Runs", len(results))
with col2: with col2:
best_score = results[f"mean_test_{selected_metric}"].max() # Best model based on F1 score
st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}") best_f1_idx = results["mean_test_f1"].idxmax()
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
# Show best parameters for the best model
with st.expander("Best Model Parameters (by F1)", expanded=True):
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
with col2:
st.metric("eps_cl", f"{best_params['eps_cl']:.2e}")
with col3: with col3:
best_idx = results[f"mean_test_{selected_metric}"].idxmax() st.metric("eps_e", f"{best_params['eps_e']:.2e}")
best_k = results.loc[best_idx, "initial_K"] with col4:
st.metric("Best K", f"{best_k:.0f}") st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}")
with col5:
st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}")
# Show all metrics
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
cols = st.columns(len(available_metrics))
for idx, metric in enumerate(available_metrics):
with cols[idx]:
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
st.metric(f"{metric.capitalize()}", f"{best_score:.4f}")
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
with tab1:
# Metric selection - only used in this tab
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Show best parameters # Show best parameters
with st.expander("Best Parameters"): with st.expander("Best Parameters"):
@ -807,10 +901,6 @@ def main():
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
st.dataframe(best_params.to_frame().T, width="content") st.dataframe(best_params.to_frame().T, width="content")
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
with tab1:
# Main plots # Main plots
st.header(f"Visualization for {selected_metric.capitalize()}") st.header(f"Visualization for {selected_metric.capitalize()}")
@ -847,7 +937,9 @@ def main():
# Reload data if bin width changed from default # Reload data if bin width changed from default
if k_bin_width != 40: if k_bin_width != 40:
with st.spinner("Re-binning data..."): with st.spinner("Re-binning data..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width) results = load_and_prepare_results(
results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width
)
# K-binned plots # K-binned plots
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
@ -885,6 +977,30 @@ def main():
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}") chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
st.altair_chart(chart4, use_container_width=True) st.altair_chart(chart4, use_container_width=True)
# Metric comparison plots
st.header("Metric Comparisons")
# Color parameter selection
color_param = st.selectbox(
"Select Color Parameter",
options=["initial_K", "eps_cl", "eps_e"],
help="Choose which parameter to use for coloring the scatter plots",
)
col1, col2 = st.columns(2)
with col1:
st.subheader("Recall vs Precision")
with st.spinner("Generating Recall vs Precision plot..."):
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
st.altair_chart(recall_precision_chart, use_container_width=True)
with col2:
st.subheader("Accuracy vs Jaccard")
with st.spinner("Generating Accuracy vs Jaccard plot..."):
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
# Optional: Raw data table # Optional: Raw data table
with st.expander("View Raw Results Data"): with st.expander("View Raw Results Data"):
st.dataframe(results, width="stretch") st.dataframe(results, width="stretch")
@ -993,7 +1109,8 @@ def main():
# Embedding features analysis (if present) # Embedding features analysis (if present)
if embedding_feature_array is not None: if embedding_feature_array is not None:
st.subheader(":artificial_satellite: Embedding Feature Analysis") with st.container(border=True):
st.header(":artificial_satellite: Embedding Feature Analysis")
st.markdown( st.markdown(
""" """
Analysis of embedding features showing which aggregations, bands, and years Analysis of embedding features showing which aggregations, bands, and years
@ -1044,7 +1161,8 @@ def main():
# ERA5 features analysis (if present) # ERA5 features analysis (if present)
if era5_feature_array is not None: if era5_feature_array is not None:
st.subheader(":partly_sunny: ERA5 Feature Analysis") with st.container(border=True):
st.header(":partly_sunny: ERA5 Feature Analysis")
st.markdown( st.markdown(
""" """
Analysis of ERA5 climate features showing which variables and time periods Analysis of ERA5 climate features showing which variables and time periods
@ -1093,7 +1211,8 @@ def main():
# Common features analysis (if present) # Common features analysis (if present)
if common_feature_array is not None: if common_feature_array is not None:
st.subheader(":world_map: Common Feature Analysis") with st.container(border=True):
st.header(":world_map: Common Feature Analysis")
st.markdown( st.markdown(
""" """
Analysis of common features including cell area, water area, land area, land ratio, Analysis of common features including cell area, water area, land area, land ratio,
@ -1133,7 +1252,8 @@ def main():
st.markdown( st.markdown(
""" """
**Interpretation:** **Interpretation:**
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns - **cell_area, water_area, land_area**: Spatial extent features that may indicate
size-related patterns
- **land_ratio**: Proportion of land vs water in each cell - **land_ratio**: Proportion of land vs water in each cell
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- Positive weights indicate features that increase the probability of the positive class - Positive weights indicate features that increase the probability of the positive class
@ -1144,13 +1264,71 @@ def main():
st.info("No common features found in this model.") st.info("No common features found in this model.")
with tab3: with tab3:
# Map visualization # Inference analysis
st.header("Predictions Map") st.header("Inference Analysis")
st.markdown("Map showing predicted classes from the best estimator") st.markdown("Comprehensive analysis of model predictions on the evaluation dataset")
with st.spinner("Generating map..."):
# Summary statistics
st.subheader("Prediction Statistics")
col1, col2 = st.columns(2)
with col1:
total_cells = len(predictions)
st.metric("Total Cells", f"{total_cells:,}")
with col2:
n_classes = predictions["predicted_class"].nunique()
st.metric("Number of Classes", n_classes)
# Class distribution visualization
st.subheader("Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_prediction_class_distribution(predictions)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between predicted classes
- Class imbalance may indicate regional patterns or model bias
- Each bar represents the count of cells predicted for that class
"""
)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
with st.spinner("Generating interactive map..."):
chart_map = _plot_prediction_map(predictions) chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[]) st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Prediction Statistics"):
st.write("**Class Distribution:**")
class_counts = predictions["predicted_class"].value_counts().sort_index()
# Create columns for better layout
n_cols = min(5, len(class_counts))
cols = st.columns(n_cols)
for idx, (class_label, count) in enumerate(class_counts.items()):
percentage = count / len(predictions) * 100
with cols[idx % n_cols]:
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
# Show detailed table
st.write("**Detailed Class Breakdown:**")
class_df = pd.DataFrame(
{
"Class": class_counts.index,
"Count": class_counts.to_numpy(),
"Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2),
}
)
st.dataframe(class_df, width="stretch", hide_index=True)
if __name__ == "__main__": if __name__ == "__main__":
main() main()