From b2cfddfeadd1752803cc9184f981f53de6acfbdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 9 Nov 2025 01:39:36 +0100 Subject: [PATCH] Switch to predict_classes --- src/entropice/inference.py | 7 +- src/entropice/training.py | 10 +- src/entropice/training_analysis_dashboard.py | 548 ++++++++++++------- 3 files changed, 373 insertions(+), 192 deletions(-) diff --git a/src/entropice/inference.py b/src/entropice/inference.py index 9b7263a..a3c9f31 100644 --- a/src/entropice/inference.py +++ b/src/entropice/inference.py @@ -18,13 +18,14 @@ pretty.install() set_config(array_api_dispatch=True) -def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier): +def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame: """Get predicted probabilities for each cell. Args: grid (Literal["hex", "healpix"]): The grid type to use. level (int): The grid level to use. clf (ESPAClassifier): The trained classifier to use for predictions. + classes (list): List of class names. Returns: list: A list of predicted probabilities for each cell. @@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() X_batch = X_batch.to_numpy(dtype="float32") X_batch = torch.asarray(X_batch, device=0) - batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy() + batch_preds = clf.predict(X_batch).cpu().numpy() batch_preds = gpd.GeoDataFrame( { "cell_id": cell_ids, - "predicted_proba": batch_preds, + "predicted_class": [classes[i] for i in batch_preds], "geometry": cell_geoms, }, crs="epsg:3413", diff --git a/src/entropice/training.py b/src/entropice/training.py index 44ee0cb..1364ad2 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -29,13 +29,14 @@ pretty.install() set_config(array_api_dispatch=True) -def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): +def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False): """Perform random cross-validation on the training dataset. Args: grid (Literal["hex", "healpix"]): The grid type to use. level (int): The grid level to use. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. + robust (bool, optional): Whether to use robust training. Defaults to False. """ data = get_train_dataset_file(grid=grid, level=level) @@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): "initial_K": randint(20, 400), } - clf = ESPAClassifier(20, 0.1, 0.1, random_state=42) + clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) cv = KFold(n_splits=5, shuffle=True, random_state=42) metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU search = RandomizedSearchCV( @@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): n_jobs=16, cv=cv, random_state=42, - verbose=10, + verbose=3, scoring=metrics, refit="f1", ) @@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): }, "cv_splits": cv.get_n_splits(), "metrics": metrics, + "classes": ["No RTS", "RTS"], } settings_file = results_dir / "search_settings.toml" print(f"Storing search settings to {settings_file}") @@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): # Predict probabilities for all cells print("Predicting probabilities for all cells...") - preds = predict_proba(grid=grid, level=level, clf=best_estimator) + preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"]) preds_file = results_dir / "predicted_probabilities.parquet" print(f"Storing predicted probabilities to {preds_file}") preds.to_parquet(preds_file) diff --git a/src/entropice/training_analysis_dashboard.py b/src/entropice/training_analysis_dashboard.py index 605b532..9e623b6 100644 --- a/src/entropice/training_analysis_dashboard.py +++ b/src/entropice/training_analysis_dashboard.py @@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]: return sorted(result_files, reverse=True) # Most recent first -def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame: +def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame: """Load results file and prepare binned columns. Args: file_path: Path to the results parquet file. + settings: Dictionary of search settings. k_bin_width: Width of bins for initial_K parameter. Returns: @@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF results = pd.read_parquet(file_path) # Automatically determine bin width for initial_K based on data range - k_min = results["initial_K"].min() - k_max = results["initial_K"].max() + k_min = settings["param_grid"]["initial_K"]["low"] + k_max = settings["param_grid"]["initial_K"]["high"] # Use configurable bin width, adapted to actual data range k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) # Automatically create logarithmic bins for epsilon parameters based on data range # Use 10 bins spanning the actual data range - eps_cl_min = np.log10(results["eps_cl"].min()) - eps_cl_max = np.log10(results["eps_cl"].max()) - eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10) + eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"]) + eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"]) + eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1)) - eps_e_min = np.log10(results["eps_e"].min()) - eps_e_max = np.log10(results["eps_e"].max()) - eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10) + eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"]) + eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"]) + eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1)) results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) @@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None: return common_feature_array -def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray: - grid = settings["grid"] - level = settings["level"] - - probs = xr.DataArray( - data=preds["predicted_proba"].to_numpy(), - dims=["cell_ids"], - coords={"cell_ids": preds["cell_id"].to_numpy()}, - ) - gridinfo = { - "grid_name": "h3" if grid == "hex" else grid, - "level": level, - } - if grid == "healpix": - gridinfo["indexing_scheme"] = "nested" - probs.cell_ids.attrs = gridinfo - probs = xdggs.decode(probs) - return probs - - def _plot_prediction_map(preds: gpd.GeoDataFrame): - return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron") - # return probs.dggs.explore() + return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron") + + +def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame): + """Create a bar chart showing the count of each predicted class. + + Args: + preds: GeoDataFrame with predicted classes. + + Returns: + Altair chart showing class distribution. + + """ + df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()}) + counts = df["predicted_class"].value_counts().reset_index() + counts.columns = ["class", "count"] + counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2) + + chart = ( + alt.Chart(counts) + .mark_bar() + .encode( + x=alt.X("class:N", title="Predicted Class"), + y=alt.Y("count:Q", title="Number of Cells"), + color=alt.Color("class:N", title="Class", legend=None), + tooltip=[ + alt.Tooltip("class:N", title="Class"), + alt.Tooltip("count:Q", title="Count"), + alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"), + ], + ) + .properties( + width=400, + height=300, + title="Predicted Class Distribution", + ) + ) + + return chart def _plot_k_binned( @@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray): return chart +def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str): + """Create a scatter plot comparing two metrics with parameter-based coloring. + + Args: + results: DataFrame containing the results with metric columns. + x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy'). + y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard'). + color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e'). + + Returns: + Altair chart showing the metric comparison. + + """ + # Determine color scale based on parameter + if color_param in ["eps_cl", "eps_e"]: + color_scale = alt.Scale(type="log", scheme="viridis") + else: + color_scale = alt.Scale(scheme="viridis") + + chart = ( + alt.Chart(results) + .mark_circle(size=60, opacity=0.7) + .encode( + x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)), + y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)), + color=alt.Color( + f"{color_param}:Q", + scale=color_scale, + title=color_param, + ), + tooltip=[ + alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()), + alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()), + alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"), + alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"), + alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"), + alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"), + ], + ) + .properties( + width=400, + height=400, + ) + .interactive() + ) + + return chart + + def main(): """Run Streamlit dashboard application.""" st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") @@ -769,7 +836,8 @@ def main(): # Load and prepare data with default bin width (will be reloaded with custom width later) with st.spinner("Loading results..."): - results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40) + settings = toml.load(results_dir / "search_settings.toml")["settings"] + results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40) model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc") n_features = model_state.sizes["feature"] model_state["feature_weights"] *= n_features @@ -777,40 +845,62 @@ def main(): era5_feature_array = extract_era5_features(model_state) common_feature_array = extract_common_features(model_state) predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") - settings = toml.load(results_dir / "search_settings.toml")["settings"] - probs = _extract_probs_as_xdggs(predictions, settings) st.sidebar.success(f"Loaded {len(results)} results") - # Metric selection - available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] - selected_metric = st.sidebar.selectbox( - "Select Metric", options=available_metrics, help="Choose which metric to visualize" - ) - # Display some basic statistics first (lightweight) - st.header("Dataset Overview") - col1, col2, col3 = st.columns(3) + st.header("Parameter-Search Overview") + + # Show total runs and best model info + col1, col2 = st.columns(2) with col1: st.metric("Total Runs", len(results)) - with col2: - best_score = results[f"mean_test_{selected_metric}"].max() - st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}") - with col3: - best_idx = results[f"mean_test_{selected_metric}"].idxmax() - best_k = results.loc[best_idx, "initial_K"] - st.metric("Best K", f"{best_k:.0f}") - # Show best parameters - with st.expander("Best Parameters"): - best_idx = results[f"mean_test_{selected_metric}"].idxmax() - best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] - st.dataframe(best_params.to_frame().T, width="content") + with col2: + # Best model based on F1 score + best_f1_idx = results["mean_test_f1"].idxmax() + st.metric("Best Model Index (by F1)", f"#{best_f1_idx}") + + # Show best parameters for the best model + with st.expander("Best Model Parameters (by F1)", expanded=True): + best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]] + col1, col2, col3, col4, col5 = st.columns(5) + with col1: + st.metric("initial_K", f"{best_params['initial_K']:.0f}") + with col2: + st.metric("eps_cl", f"{best_params['eps_cl']:.2e}") + with col3: + st.metric("eps_e", f"{best_params['eps_e']:.2e}") + with col4: + st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}") + with col5: + st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}") + + # Show all metrics + available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] + cols = st.columns(len(available_metrics)) + + for idx, metric in enumerate(available_metrics): + with cols[idx]: + best_score = results.loc[best_f1_idx, f"mean_test_{metric}"] + st.metric(f"{metric.capitalize()}", f"{best_score:.4f}") # Create tabs for different visualizations - tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"]) + tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"]) with tab1: + # Metric selection - only used in this tab + available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] + selected_metric = st.selectbox( + "Select Metric", options=available_metrics, help="Choose which metric to visualize" + ) + + # Show best parameters + with st.expander("Best Parameters"): + best_idx = results[f"mean_test_{selected_metric}"].idxmax() + best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] + st.dataframe(best_params.to_frame().T, width="content") + # Main plots st.header(f"Visualization for {selected_metric.capitalize()}") @@ -847,7 +937,9 @@ def main(): # Reload data if bin width changed from default if k_bin_width != 40: with st.spinner("Re-binning data..."): - results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width) + results = load_and_prepare_results( + results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width + ) # K-binned plots col1, col2 = st.columns(2) @@ -885,6 +977,30 @@ def main(): chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}") st.altair_chart(chart4, use_container_width=True) + # Metric comparison plots + st.header("Metric Comparisons") + + # Color parameter selection + color_param = st.selectbox( + "Select Color Parameter", + options=["initial_K", "eps_cl", "eps_e"], + help="Choose which parameter to use for coloring the scatter plots", + ) + + col1, col2 = st.columns(2) + + with col1: + st.subheader("Recall vs Precision") + with st.spinner("Generating Recall vs Precision plot..."): + recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param) + st.altair_chart(recall_precision_chart, use_container_width=True) + + with col2: + st.subheader("Accuracy vs Jaccard") + with st.spinner("Generating Accuracy vs Jaccard plot..."): + accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param) + st.altair_chart(accuracy_jaccard_chart, use_container_width=True) + # Optional: Raw data table with st.expander("View Raw Results Data"): st.dataframe(results, width="stretch") @@ -993,164 +1109,226 @@ def main(): # Embedding features analysis (if present) if embedding_feature_array is not None: - st.subheader(":artificial_satellite: Embedding Feature Analysis") - st.markdown( - """ - Analysis of embedding features showing which aggregations, bands, and years - are most important for the model predictions. - """ - ) + with st.container(border=True): + st.header(":artificial_satellite: Embedding Feature Analysis") + st.markdown( + """ + Analysis of embedding features showing which aggregations, bands, and years + are most important for the model predictions. + """ + ) - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating dimension summaries..."): - chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array) - col1, col2, col3 = st.columns(3) - with col1: - st.altair_chart(chart_agg, use_container_width=True) - with col2: - st.altair_chart(chart_band, use_container_width=True) - with col3: - st.altair_chart(chart_year, use_container_width=True) + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating dimension summaries..."): + chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array) + col1, col2, col3 = st.columns(3) + with col1: + st.altair_chart(chart_agg, use_container_width=True) + with col2: + st.altair_chart(chart_band, use_container_width=True) + with col3: + st.altair_chart(chart_year, use_container_width=True) - # Detailed heatmap - st.markdown("### Detailed Heatmap by Aggregation") - st.markdown("Shows the weight of each band-year combination for each aggregation type.") - with st.spinner("Generating heatmap..."): - heatmap_chart = _plot_embedding_heatmap(embedding_feature_array) - st.altair_chart(heatmap_chart, use_container_width=True) + # Detailed heatmap + st.markdown("### Detailed Heatmap by Aggregation") + st.markdown("Shows the weight of each band-year combination for each aggregation type.") + with st.spinner("Generating heatmap..."): + heatmap_chart = _plot_embedding_heatmap(embedding_feature_array) + st.altair_chart(heatmap_chart, use_container_width=True) - # Statistics - with st.expander("Embedding Feature Statistics"): - st.write("**Overall Statistics:**") - n_emb_features = embedding_feature_array.size - mean_weight = float(embedding_feature_array.mean().values) - max_weight = float(embedding_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total Embedding Features", n_emb_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") + # Statistics + with st.expander("Embedding Feature Statistics"): + st.write("**Overall Statistics:**") + n_emb_features = embedding_feature_array.size + mean_weight = float(embedding_feature_array.mean().values) + max_weight = float(embedding_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Embedding Features", n_emb_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") - # Show top embedding features - st.write("**Top 10 Embedding Features:**") - emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() - top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] - st.dataframe(top_emb, width="stretch") + # Show top embedding features + st.write("**Top 10 Embedding Features:**") + emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() + top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] + st.dataframe(top_emb, width="stretch") else: st.info("No embedding features found in this model.") # ERA5 features analysis (if present) if era5_feature_array is not None: - st.subheader(":partly_sunny: ERA5 Feature Analysis") - st.markdown( - """ - Analysis of ERA5 climate features showing which variables and time periods - are most important for the model predictions. - """ - ) + with st.container(border=True): + st.header(":partly_sunny: ERA5 Feature Analysis") + st.markdown( + """ + Analysis of ERA5 climate features showing which variables and time periods + are most important for the model predictions. + """ + ) - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating ERA5 dimension summaries..."): - chart_variable, chart_time = _plot_era5_summary(era5_feature_array) - col1, col2 = st.columns(2) - with col1: - st.altair_chart(chart_variable, use_container_width=True) - with col2: - st.altair_chart(chart_time, use_container_width=True) + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating ERA5 dimension summaries..."): + chart_variable, chart_time = _plot_era5_summary(era5_feature_array) + col1, col2 = st.columns(2) + with col1: + st.altair_chart(chart_variable, use_container_width=True) + with col2: + st.altair_chart(chart_time, use_container_width=True) - # Detailed heatmap - st.markdown("### Detailed Heatmap") - st.markdown("Shows the weight of each variable-time combination.") - with st.spinner("Generating ERA5 heatmap..."): - era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array) - st.altair_chart(era5_heatmap_chart, use_container_width=True) + # Detailed heatmap + st.markdown("### Detailed Heatmap") + st.markdown("Shows the weight of each variable-time combination.") + with st.spinner("Generating ERA5 heatmap..."): + era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array) + st.altair_chart(era5_heatmap_chart, use_container_width=True) - # Statistics - with st.expander("ERA5 Feature Statistics"): - st.write("**Overall Statistics:**") - n_era5_features = era5_feature_array.size - mean_weight = float(era5_feature_array.mean().values) - max_weight = float(era5_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total ERA5 Features", n_era5_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") + # Statistics + with st.expander("ERA5 Feature Statistics"): + st.write("**Overall Statistics:**") + n_era5_features = era5_feature_array.size + mean_weight = float(era5_feature_array.mean().values) + max_weight = float(era5_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total ERA5 Features", n_era5_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") - # Show top ERA5 features - st.write("**Top 10 ERA5 Features:**") - era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() - top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] - st.dataframe(top_era5, width="stretch") + # Show top ERA5 features + st.write("**Top 10 ERA5 Features:**") + era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() + top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] + st.dataframe(top_era5, width="stretch") else: st.info("No ERA5 features found in this model.") # Common features analysis (if present) if common_feature_array is not None: - st.subheader(":world_map: Common Feature Analysis") - st.markdown( - """ - Analysis of common features including cell area, water area, land area, land ratio, - longitude, and latitude. These features provide spatial and geographic context. - """ - ) + with st.container(border=True): + st.header(":world_map: Common Feature Analysis") + st.markdown( + """ + Analysis of common features including cell area, water area, land area, land ratio, + longitude, and latitude. These features provide spatial and geographic context. + """ + ) - # Bar chart showing all common feature weights - with st.spinner("Generating common features chart..."): - common_chart = _plot_common_features(common_feature_array) - st.altair_chart(common_chart, use_container_width=True) + # Bar chart showing all common feature weights + with st.spinner("Generating common features chart..."): + common_chart = _plot_common_features(common_feature_array) + st.altair_chart(common_chart, use_container_width=True) - # Statistics - with st.expander("Common Feature Statistics"): - st.write("**Overall Statistics:**") - n_common_features = common_feature_array.size - mean_weight = float(common_feature_array.mean().values) - max_weight = float(common_feature_array.max().values) - min_weight = float(common_feature_array.min().values) - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Common Features", n_common_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - with col4: - st.metric("Min Weight", f"{min_weight:.4f}") + # Statistics + with st.expander("Common Feature Statistics"): + st.write("**Overall Statistics:**") + n_common_features = common_feature_array.size + mean_weight = float(common_feature_array.mean().values) + max_weight = float(common_feature_array.max().values) + min_weight = float(common_feature_array.min().values) + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Common Features", n_common_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + with col4: + st.metric("Min Weight", f"{min_weight:.4f}") - # Show all common features sorted by importance - st.write("**All Common Features (by absolute weight):**") - common_df = common_feature_array.to_dataframe(name="weight").reset_index() - common_df["abs_weight"] = common_df["weight"].abs() - common_df = common_df.sort_values("abs_weight", ascending=False) - st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") + # Show all common features sorted by importance + st.write("**All Common Features (by absolute weight):**") + common_df = common_feature_array.to_dataframe(name="weight").reset_index() + common_df["abs_weight"] = common_df["weight"].abs() + common_df = common_df.sort_values("abs_weight", ascending=False) + st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") - st.markdown( - """ - **Interpretation:** - - **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns - - **land_ratio**: Proportion of land vs water in each cell - - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns - - Positive weights indicate features that increase the probability of the positive class - - Negative weights indicate features that decrease the probability of the positive class - """ - ) + st.markdown( + """ + **Interpretation:** + - **cell_area, water_area, land_area**: Spatial extent features that may indicate + size-related patterns + - **land_ratio**: Proportion of land vs water in each cell + - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns + - Positive weights indicate features that increase the probability of the positive class + - Negative weights indicate features that decrease the probability of the positive class + """ + ) else: st.info("No common features found in this model.") with tab3: - # Map visualization - st.header("Predictions Map") - st.markdown("Map showing predicted classes from the best estimator") - with st.spinner("Generating map..."): + # Inference analysis + st.header("Inference Analysis") + st.markdown("Comprehensive analysis of model predictions on the evaluation dataset") + + # Summary statistics + st.subheader("Prediction Statistics") + col1, col2 = st.columns(2) + + with col1: + total_cells = len(predictions) + st.metric("Total Cells", f"{total_cells:,}") + + with col2: + n_classes = predictions["predicted_class"].nunique() + st.metric("Number of Classes", n_classes) + + # Class distribution visualization + st.subheader("Class Distribution") + + with st.spinner("Generating class distribution..."): + class_dist_chart = _plot_prediction_class_distribution(predictions) + st.altair_chart(class_dist_chart, use_container_width=True) + + st.markdown( + """ + **Interpretation:** + - Shows the balance between predicted classes + - Class imbalance may indicate regional patterns or model bias + - Each bar represents the count of cells predicted for that class + """ + ) + + # Interactive map + st.subheader("Interactive Prediction Map") + st.markdown("Explore predictions spatially with the interactive map below") + + with st.spinner("Generating interactive map..."): chart_map = _plot_prediction_map(predictions) st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[]) + # Additional statistics in expander + with st.expander("Detailed Prediction Statistics"): + st.write("**Class Distribution:**") + class_counts = predictions["predicted_class"].value_counts().sort_index() + + # Create columns for better layout + n_cols = min(5, len(class_counts)) + cols = st.columns(n_cols) + + for idx, (class_label, count) in enumerate(class_counts.items()): + percentage = count / len(predictions) * 100 + with cols[idx % n_cols]: + st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)") + + # Show detailed table + st.write("**Detailed Class Breakdown:**") + class_df = pd.DataFrame( + { + "Class": class_counts.index, + "Count": class_counts.to_numpy(), + "Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2), + } + ) + st.dataframe(class_df, width="stretch", hide_index=True) + if __name__ == "__main__": main()