From d4a747d80066f08c0f0f5decd9e55dc98a90afd8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Sun, 9 Nov 2025 00:27:42 +0100
Subject: [PATCH] Fix Map and add common feature analysis

---
Review notes (ignored by git am):
- main() no longer loads search_settings.toml / builds the xdggs array: the
  result was never used, so the now-unused `toml` import is dropped as well.
- _extract_probs_as_xdggs and _plot_prediction_map are documented.
- Hunk headers and the diffstat were recomputed by hand; the blob hashes on
  the `index` line predate these edits.

 src/entropice/training_analysis_dashboard.py | 191 +++++++++++++++++--
 1 file changed, 170 insertions(+), 21 deletions(-)

diff --git a/src/entropice/training_analysis_dashboard.py b/src/entropice/training_analysis_dashboard.py
index 5029c17..605b532 100644
--- a/src/entropice/training_analysis_dashboard.py
+++ b/src/entropice/training_analysis_dashboard.py
@@ -4,13 +4,13 @@ from datetime import datetime
 from pathlib import Path
 
 import altair as alt
-import folium
 import geopandas as gpd
 import numpy as np
 import pandas as pd
 import streamlit as st
+import streamlit_folium as st_folium
 import xarray as xr
-from streamlit_folium import st_folium
+import xdggs
 
 from entropice.paths import RESULTS_DIR
 
@@ -158,21 +158,70 @@ def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
     return era5_features_array
 
 
-# TODO: Extract common features, e.g. area or water content
-
-
-def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
-    """Plot predicted probabilities on a map using Streamlit and Folium.
+def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
+    """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.
 
     Args:
-        preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
+        model_state: The xarray Dataset containing the model state.
 
     Returns:
-        folium.Map: A Folium map object with the predicted probabilities visualized.
+        xr.DataArray: The extracted common features with a single 'feature' dimension.
+        Returns None if no common features are found.
 
     """
-    m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
-    return m
+    common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
+
+    # Preserve the order in which the features appear in the model state
+    common_features = [f for f in model_state.feature.to_numpy() if f in common_feature_names]
+    if not common_features:
+        return None
+
+    # Extract the feature weights for common features
+    return model_state.sel(feature=common_features)["feature_weights"]
+
+
+def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
+    """Convert predicted probabilities into an xdggs-decoded DataArray.
+
+    Args:
+        preds: GeoDataFrame containing 'predicted_proba' and 'cell_id' columns.
+        settings: Search settings providing the 'grid' name and the grid 'level'.
+
+    Returns:
+        xr.DataArray: The probabilities along a 'cell_ids' dimension, decoded with xdggs.
+
+    """
+    grid = settings["grid"]
+    level = settings["level"]
+
+    probs = xr.DataArray(
+        data=preds["predicted_proba"].to_numpy(),
+        dims=["cell_ids"],
+        coords={"cell_ids": preds["cell_id"].to_numpy()},
+    )
+    gridinfo = {
+        # Map the 'hex' grid setting to the 'h3' grid name; other names are passed through unchanged
+        "grid_name": "h3" if grid == "hex" else grid,
+        "level": level,
+    }
+    if grid == "healpix":
+        gridinfo["indexing_scheme"] = "nested"
+    probs.cell_ids.attrs = gridinfo
+    return xdggs.decode(probs)
+
+
+def _plot_prediction_map(preds: gpd.GeoDataFrame):
+    """Plot predicted probabilities on an interactive map.
+
+    Args:
+        preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
+
+    Returns:
+        The map object returned by GeoDataFrame.explore.
+
+    """
+    # TODO: Switch to the xdggs-based map (`probs.dggs.explore()`) once it is wired into the dashboard
+    return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
 
 
 def _plot_k_binned(
@@ -661,6 +710,51 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
     return chart
 
 
+def _plot_common_features(common_array: xr.DataArray):
+    """Create a bar chart showing the weights of common features.
+
+    Args:
+        common_array: DataArray with dimension (feature) containing feature weights.
+
+    Returns:
+        Altair chart showing the common feature weights.
+
+    """
+    # Convert to DataFrame for plotting
+    df = common_array.to_dataframe(name="weight").reset_index()
+
+    # Sort by absolute weight
+    df["abs_weight"] = df["weight"].abs()
+    df = df.sort_values("abs_weight", ascending=True)
+
+    # Create bar chart
+    chart = (
+        alt.Chart(df)
+        .mark_bar()
+        .encode(
+            y=alt.Y("feature:N", title="Feature", sort="-x"),
+            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
+            color=alt.condition(
+                alt.datum.weight > 0,
+                alt.value("steelblue"),  # Positive weights
+                alt.value("coral"),  # Negative weights
+            ),
+            tooltip=[
+                alt.Tooltip("feature:N", title="Feature"),
+                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
+            ],
+        )
+        .properties(
+            width=600,
+            height=300,
+            title="Common Feature Weights",
+        )
+    )
+
+    return chart
+
+
 def main():
     """Run Streamlit dashboard application."""
     st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@@ -696,6 +790,7 @@ def main():
     model_state["feature_weights"] *= n_features
     embedding_feature_array = extract_embedding_features(model_state)
     era5_feature_array = extract_era5_features(model_state)
+    common_feature_array = extract_common_features(model_state)
     predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
 
     st.sidebar.success(f"Loaded {len(results)} results")
@@ -706,13 +801,7 @@ def main():
         "Select Metric", options=available_metrics, help="Choose which metric to visualize"
     )
 
-    # Map visualization
-    st.header("Predictions Map")
-    st.markdown("Map showing predicted classes from the best estimator")
-    m = _plot_prediction_map(predictions)
-    st_folium(m, width="100%", height=300)
-
-    # Display some basic statistics
+    # Display some basic statistics first (lightweight)
     st.header("Dataset Overview")
     col1, col2, col3 = st.columns(3)
     with col1:
@@ -732,7 +821,7 @@ def main():
     st.dataframe(best_params.to_frame().T, width="content")
 
     # Create tabs for different visualizations
-    tab1, tab2 = st.tabs(["Search Results", "Model State"])
+    tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
 
     with tab1:
         # Main plots
@@ -917,7 +1006,7 @@ def main():
 
         # Embedding features analysis (if present)
         if embedding_feature_array is not None:
-            st.subheader("Embedding Feature Analysis")
+            st.subheader(":artificial_satellite: Embedding Feature Analysis")
             st.markdown(
                 """
                 Analysis of embedding features showing which aggregations, bands, and years
@@ -968,7 +1057,7 @@ def main():
 
         # ERA5 features analysis (if present)
         if era5_feature_array is not None:
-            st.subheader("ERA5 Feature Analysis")
+            st.subheader(":partly_sunny: ERA5 Feature Analysis")
             st.markdown(
                 """
                 Analysis of ERA5 climate features showing which variables and time periods
@@ -1015,6 +1104,66 @@ def main():
         else:
             st.info("No ERA5 features found in this model.")
 
+        # Common features analysis (if present)
+        if common_feature_array is not None:
+            st.subheader(":world_map: Common Feature Analysis")
+            st.markdown(
+                """
+                Analysis of common features including cell area, water area, land area, land ratio,
+                longitude, and latitude. These features provide spatial and geographic context.
+                """
+            )
+
+            # Bar chart showing all common feature weights
+            with st.spinner("Generating common features chart..."):
+                common_chart = _plot_common_features(common_feature_array)
+                st.altair_chart(common_chart, use_container_width=True)
+
+            # Statistics
+            with st.expander("Common Feature Statistics"):
+                st.write("**Overall Statistics:**")
+                n_common_features = common_feature_array.size
+                mean_weight = float(common_feature_array.mean().values)
+                max_weight = float(common_feature_array.max().values)
+                min_weight = float(common_feature_array.min().values)
+                col1, col2, col3, col4 = st.columns(4)
+                with col1:
+                    st.metric("Total Common Features", n_common_features)
+                with col2:
+                    st.metric("Mean Weight", f"{mean_weight:.4f}")
+                with col3:
+                    st.metric("Max Weight", f"{max_weight:.4f}")
+                with col4:
+                    st.metric("Min Weight", f"{min_weight:.4f}")
+
+                # Show all common features sorted by importance
+                st.write("**All Common Features (by absolute weight):**")
+                common_df = common_feature_array.to_dataframe(name="weight").reset_index()
+                common_df["abs_weight"] = common_df["weight"].abs()
+                common_df = common_df.sort_values("abs_weight", ascending=False)
+                st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
+
+            st.markdown(
+                """
+                **Interpretation:**
+                - **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
+                - **land_ratio**: Proportion of land vs water in each cell
+                - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
+                - Positive weights indicate features that increase the probability of the positive class
+                - Negative weights indicate features that decrease the probability of the positive class
+                """
+            )
+        else:
+            st.info("No common features found in this model.")
+
+    with tab3:
+        # Map visualization
+        st.header("Predictions Map")
+        st.markdown("Map showing predicted classes from the best estimator")
+        with st.spinner("Generating map..."):
+            chart_map = _plot_prediction_map(predictions)
+            st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
+
 
 if __name__ == "__main__":
     main()