Fix Map and add common feature analysis

This commit is contained in:
Tobias Hölzer 2025-11-09 00:27:42 +01:00
parent fb522ddad5
commit d4a747d800

View file

@ -4,13 +4,14 @@ from datetime import datetime
from pathlib import Path
import altair as alt
import folium
import geopandas as gpd
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_folium as st_folium
import toml
import xarray as xr
from streamlit_folium import st_folium
import xdggs
from entropice.paths import RESULTS_DIR
@ -158,21 +159,54 @@ def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
return era5_features_array
# TODO: Extract common features, e.g. area or water content
def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
"""Plot predicted probabilities on a map using Streamlit and Folium.
def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
"""Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.
Args:
preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
model_state: The xarray Dataset containing the model state.
Returns:
folium.Map: A Folium map object with the predicted probabilities visualized.
xr.DataArray: The extracted common features with a single 'feature' dimension.
Returns None if no common features are found.
"""
m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
return m
common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
def _is_common_feature(feature: str) -> bool:
return feature in common_feature_names
common_features = [f for f in model_state.feature.to_numpy() if _is_common_feature(f)]
if len(common_features) == 0:
return None
# Extract the feature weights for common features
common_feature_array = model_state.sel(feature=common_features)["feature_weights"]
return common_feature_array
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
grid = settings["grid"]
level = settings["level"]
probs = xr.DataArray(
data=preds["predicted_proba"].to_numpy(),
dims=["cell_ids"],
coords={"cell_ids": preds["cell_id"].to_numpy()},
)
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
}
if grid == "healpix":
gridinfo["indexing_scheme"] = "nested"
probs.cell_ids.attrs = gridinfo
probs = xdggs.decode(probs)
return probs
def _plot_prediction_map(preds: gpd.GeoDataFrame):
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
# return probs.dggs.explore()
def _plot_k_binned(
@ -661,6 +695,51 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
return chart
def _plot_common_features(common_array: xr.DataArray):
"""Create a bar chart showing the weights of common features.
Args:
common_array: DataArray with dimension (feature) containing feature weights.
Returns:
Altair chart showing the common feature weights.
"""
# Convert to DataFrame for plotting
df = common_array.to_dataframe(name="weight").reset_index()
# Sort by absolute weight
df["abs_weight"] = df["weight"].abs()
df = df.sort_values("abs_weight", ascending=True)
# Create bar chart
chart = (
alt.Chart(df)
.mark_bar()
.encode(
y=alt.Y("feature:N", title="Feature", sort="-x"),
x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
color=alt.condition(
alt.datum.weight > 0,
alt.value("steelblue"), # Positive weights
alt.value("coral"), # Negative weights
),
tooltip=[
alt.Tooltip("feature:N", title="Feature"),
alt.Tooltip("weight:Q", format=".4f", title="Weight"),
alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
],
)
.properties(
width=600,
height=300,
title="Common Feature Weights",
)
)
return chart
def main():
"""Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -696,7 +775,10 @@ def main():
model_state["feature_weights"] *= n_features
embedding_feature_array = extract_embedding_features(model_state)
era5_feature_array = extract_era5_features(model_state)
common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
settings = toml.load(results_dir / "search_settings.toml")["settings"]
probs = _extract_probs_as_xdggs(predictions, settings)
st.sidebar.success(f"Loaded {len(results)} results")
@ -706,13 +788,7 @@ def main():
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Map visualization
st.header("Predictions Map")
st.markdown("Map showing predicted classes from the best estimator")
m = _plot_prediction_map(predictions)
st_folium(m, width="100%", height=300)
# Display some basic statistics
# Display some basic statistics first (lightweight)
st.header("Dataset Overview")
col1, col2, col3 = st.columns(3)
with col1:
@ -732,7 +808,7 @@ def main():
st.dataframe(best_params.to_frame().T, width="content")
# Create tabs for different visualizations
tab1, tab2 = st.tabs(["Search Results", "Model State"])
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
with tab1:
# Main plots
@ -917,7 +993,7 @@ def main():
# Embedding features analysis (if present)
if embedding_feature_array is not None:
st.subheader("Embedding Feature Analysis")
st.subheader(":artificial_satellite: Embedding Feature Analysis")
st.markdown(
"""
Analysis of embedding features showing which aggregations, bands, and years
@ -968,7 +1044,7 @@ def main():
# ERA5 features analysis (if present)
if era5_feature_array is not None:
st.subheader("ERA5 Feature Analysis")
st.subheader(":partly_sunny: ERA5 Feature Analysis")
st.markdown(
"""
Analysis of ERA5 climate features showing which variables and time periods
@ -1015,6 +1091,66 @@ def main():
else:
st.info("No ERA5 features found in this model.")
# Common features analysis (if present)
if common_feature_array is not None:
st.subheader(":world_map: Common Feature Analysis")
st.markdown(
"""
Analysis of common features including cell area, water area, land area, land ratio,
longitude, and latitude. These features provide spatial and geographic context.
"""
)
# Bar chart showing all common feature weights
with st.spinner("Generating common features chart..."):
common_chart = _plot_common_features(common_feature_array)
st.altair_chart(common_chart, use_container_width=True)
# Statistics
with st.expander("Common Feature Statistics"):
st.write("**Overall Statistics:**")
n_common_features = common_feature_array.size
mean_weight = float(common_feature_array.mean().values)
max_weight = float(common_feature_array.max().values)
min_weight = float(common_feature_array.min().values)
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Common Features", n_common_features)
with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3:
st.metric("Max Weight", f"{max_weight:.4f}")
with col4:
st.metric("Min Weight", f"{min_weight:.4f}")
# Show all common features sorted by importance
st.write("**All Common Features (by absolute weight):**")
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
common_df["abs_weight"] = common_df["weight"].abs()
common_df = common_df.sort_values("abs_weight", ascending=False)
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
st.markdown(
"""
**Interpretation:**
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
- **land_ratio**: Proportion of land vs water in each cell
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- Positive weights indicate features that increase the probability of the positive class
- Negative weights indicate features that decrease the probability of the positive class
"""
)
else:
st.info("No common features found in this model.")
with tab3:
# Map visualization
st.header("Predictions Map")
st.markdown("Map showing predicted classes from the best estimator")
with st.spinner("Generating map..."):
chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
if __name__ == "__main__":
main()