Fix Map and add common feature analysis
This commit is contained in:
parent
fb522ddad5
commit
d4a747d800
1 changed files with 157 additions and 21 deletions
|
|
@ -4,13 +4,14 @@ from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import altair as alt
|
import altair as alt
|
||||||
import folium
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import streamlit_folium as st_folium
|
||||||
|
import toml
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from streamlit_folium import st_folium
|
import xdggs
|
||||||
|
|
||||||
from entropice.paths import RESULTS_DIR
|
from entropice.paths import RESULTS_DIR
|
||||||
|
|
||||||
|
|
@ -158,21 +159,54 @@ def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
return era5_features_array
|
return era5_features_array
|
||||||
|
|
||||||
|
|
||||||
# TODO: Extract common features, e.g. area or water content
|
def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
|
"""Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.
|
||||||
|
|
||||||
def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
|
|
||||||
"""Plot predicted probabilities on a map using Streamlit and Folium.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
|
model_state: The xarray Dataset containing the model state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
folium.Map: A Folium map object with the predicted probabilities visualized.
|
xr.DataArray: The extracted common features with a single 'feature' dimension.
|
||||||
|
Returns None if no common features are found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
|
||||||
return m
|
|
||||||
|
def _is_common_feature(feature: str) -> bool:
|
||||||
|
return feature in common_feature_names
|
||||||
|
|
||||||
|
common_features = [f for f in model_state.feature.to_numpy() if _is_common_feature(f)]
|
||||||
|
if len(common_features) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract the feature weights for common features
|
||||||
|
common_feature_array = model_state.sel(feature=common_features)["feature_weights"]
|
||||||
|
return common_feature_array
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
|
||||||
|
grid = settings["grid"]
|
||||||
|
level = settings["level"]
|
||||||
|
|
||||||
|
probs = xr.DataArray(
|
||||||
|
data=preds["predicted_proba"].to_numpy(),
|
||||||
|
dims=["cell_ids"],
|
||||||
|
coords={"cell_ids": preds["cell_id"].to_numpy()},
|
||||||
|
)
|
||||||
|
gridinfo = {
|
||||||
|
"grid_name": "h3" if grid == "hex" else grid,
|
||||||
|
"level": level,
|
||||||
|
}
|
||||||
|
if grid == "healpix":
|
||||||
|
gridinfo["indexing_scheme"] = "nested"
|
||||||
|
probs.cell_ids.attrs = gridinfo
|
||||||
|
probs = xdggs.decode(probs)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_prediction_map(preds: gpd.GeoDataFrame):
|
||||||
|
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||||
|
# return probs.dggs.explore()
|
||||||
|
|
||||||
|
|
||||||
def _plot_k_binned(
|
def _plot_k_binned(
|
||||||
|
|
@ -661,6 +695,51 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_common_features(common_array: xr.DataArray):
|
||||||
|
"""Create a bar chart showing the weights of common features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
common_array: DataArray with dimension (feature) containing feature weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Altair chart showing the common feature weights.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Convert to DataFrame for plotting
|
||||||
|
df = common_array.to_dataframe(name="weight").reset_index()
|
||||||
|
|
||||||
|
# Sort by absolute weight
|
||||||
|
df["abs_weight"] = df["weight"].abs()
|
||||||
|
df = df.sort_values("abs_weight", ascending=True)
|
||||||
|
|
||||||
|
# Create bar chart
|
||||||
|
chart = (
|
||||||
|
alt.Chart(df)
|
||||||
|
.mark_bar()
|
||||||
|
.encode(
|
||||||
|
y=alt.Y("feature:N", title="Feature", sort="-x"),
|
||||||
|
x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
|
||||||
|
color=alt.condition(
|
||||||
|
alt.datum.weight > 0,
|
||||||
|
alt.value("steelblue"), # Positive weights
|
||||||
|
alt.value("coral"), # Negative weights
|
||||||
|
),
|
||||||
|
tooltip=[
|
||||||
|
alt.Tooltip("feature:N", title="Feature"),
|
||||||
|
alt.Tooltip("weight:Q", format=".4f", title="Weight"),
|
||||||
|
alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.properties(
|
||||||
|
width=600,
|
||||||
|
height=300,
|
||||||
|
title="Common Feature Weights",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run Streamlit dashboard application."""
|
"""Run Streamlit dashboard application."""
|
||||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||||
|
|
@ -696,7 +775,10 @@ def main():
|
||||||
model_state["feature_weights"] *= n_features
|
model_state["feature_weights"] *= n_features
|
||||||
embedding_feature_array = extract_embedding_features(model_state)
|
embedding_feature_array = extract_embedding_features(model_state)
|
||||||
era5_feature_array = extract_era5_features(model_state)
|
era5_feature_array = extract_era5_features(model_state)
|
||||||
|
common_feature_array = extract_common_features(model_state)
|
||||||
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||||
|
settings = toml.load(results_dir / "search_settings.toml")["settings"]
|
||||||
|
probs = _extract_probs_as_xdggs(predictions, settings)
|
||||||
|
|
||||||
st.sidebar.success(f"Loaded {len(results)} results")
|
st.sidebar.success(f"Loaded {len(results)} results")
|
||||||
|
|
||||||
|
|
@ -706,13 +788,7 @@ def main():
|
||||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map visualization
|
# Display some basic statistics first (lightweight)
|
||||||
st.header("Predictions Map")
|
|
||||||
st.markdown("Map showing predicted classes from the best estimator")
|
|
||||||
m = _plot_prediction_map(predictions)
|
|
||||||
st_folium(m, width="100%", height=300)
|
|
||||||
|
|
||||||
# Display some basic statistics
|
|
||||||
st.header("Dataset Overview")
|
st.header("Dataset Overview")
|
||||||
col1, col2, col3 = st.columns(3)
|
col1, col2, col3 = st.columns(3)
|
||||||
with col1:
|
with col1:
|
||||||
|
|
@ -732,7 +808,7 @@ def main():
|
||||||
st.dataframe(best_params.to_frame().T, width="content")
|
st.dataframe(best_params.to_frame().T, width="content")
|
||||||
|
|
||||||
# Create tabs for different visualizations
|
# Create tabs for different visualizations
|
||||||
tab1, tab2 = st.tabs(["Search Results", "Model State"])
|
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
|
||||||
|
|
||||||
with tab1:
|
with tab1:
|
||||||
# Main plots
|
# Main plots
|
||||||
|
|
@ -917,7 +993,7 @@ def main():
|
||||||
|
|
||||||
# Embedding features analysis (if present)
|
# Embedding features analysis (if present)
|
||||||
if embedding_feature_array is not None:
|
if embedding_feature_array is not None:
|
||||||
st.subheader("Embedding Feature Analysis")
|
st.subheader(":artificial_satellite: Embedding Feature Analysis")
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
Analysis of embedding features showing which aggregations, bands, and years
|
Analysis of embedding features showing which aggregations, bands, and years
|
||||||
|
|
@ -968,7 +1044,7 @@ def main():
|
||||||
|
|
||||||
# ERA5 features analysis (if present)
|
# ERA5 features analysis (if present)
|
||||||
if era5_feature_array is not None:
|
if era5_feature_array is not None:
|
||||||
st.subheader("ERA5 Feature Analysis")
|
st.subheader(":partly_sunny: ERA5 Feature Analysis")
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
Analysis of ERA5 climate features showing which variables and time periods
|
Analysis of ERA5 climate features showing which variables and time periods
|
||||||
|
|
@ -1015,6 +1091,66 @@ def main():
|
||||||
else:
|
else:
|
||||||
st.info("No ERA5 features found in this model.")
|
st.info("No ERA5 features found in this model.")
|
||||||
|
|
||||||
|
# Common features analysis (if present)
|
||||||
|
if common_feature_array is not None:
|
||||||
|
st.subheader(":world_map: Common Feature Analysis")
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Analysis of common features including cell area, water area, land area, land ratio,
|
||||||
|
longitude, and latitude. These features provide spatial and geographic context.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bar chart showing all common feature weights
|
||||||
|
with st.spinner("Generating common features chart..."):
|
||||||
|
common_chart = _plot_common_features(common_feature_array)
|
||||||
|
st.altair_chart(common_chart, use_container_width=True)
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
with st.expander("Common Feature Statistics"):
|
||||||
|
st.write("**Overall Statistics:**")
|
||||||
|
n_common_features = common_feature_array.size
|
||||||
|
mean_weight = float(common_feature_array.mean().values)
|
||||||
|
max_weight = float(common_feature_array.max().values)
|
||||||
|
min_weight = float(common_feature_array.min().values)
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Common Features", n_common_features)
|
||||||
|
with col2:
|
||||||
|
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||||
|
with col3:
|
||||||
|
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||||
|
with col4:
|
||||||
|
st.metric("Min Weight", f"{min_weight:.4f}")
|
||||||
|
|
||||||
|
# Show all common features sorted by importance
|
||||||
|
st.write("**All Common Features (by absolute weight):**")
|
||||||
|
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
|
||||||
|
common_df["abs_weight"] = common_df["weight"].abs()
|
||||||
|
common_df = common_df.sort_values("abs_weight", ascending=False)
|
||||||
|
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Interpretation:**
|
||||||
|
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
|
||||||
|
- **land_ratio**: Proportion of land vs water in each cell
|
||||||
|
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||||
|
- Positive weights indicate features that increase the probability of the positive class
|
||||||
|
- Negative weights indicate features that decrease the probability of the positive class
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
st.info("No common features found in this model.")
|
||||||
|
|
||||||
|
with tab3:
|
||||||
|
# Map visualization
|
||||||
|
st.header("Predictions Map")
|
||||||
|
st.markdown("Map showing predicted classes from the best estimator")
|
||||||
|
with st.spinner("Generating map..."):
|
||||||
|
chart_map = _plot_prediction_map(predictions)
|
||||||
|
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue