Switch to predict_classes

This commit is contained in:
Tobias Hölzer 2025-11-09 01:39:36 +01:00
parent d4a747d800
commit b2cfddfead
3 changed files with 373 additions and 192 deletions

View file

@ -18,13 +18,14 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier): def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
"""Get predicted probabilities for each cell. """Get predicted probabilities for each cell.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
clf (ESPAClassifier): The trained classifier to use for predictions. clf (ESPAClassifier): The trained classifier to use for predictions.
classes (list): List of class names.
Returns: Returns:
list: A list of predicted probabilities for each cell. list: A list of predicted probabilities for each cell.
@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32") X_batch = X_batch.to_numpy(dtype="float32")
X_batch = torch.asarray(X_batch, device=0) X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy() batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = gpd.GeoDataFrame( batch_preds = gpd.GeoDataFrame(
{ {
"cell_id": cell_ids, "cell_id": cell_ids,
"predicted_proba": batch_preds, "predicted_class": [classes[i] for i in batch_preds],
"geometry": cell_geoms, "geometry": cell_geoms,
}, },
crs="epsg:3413", crs="epsg:3413",

View file

@ -29,13 +29,14 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False.
""" """
data = get_train_dataset_file(grid=grid, level=level) data = get_train_dataset_file(grid=grid, level=level)
@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"initial_K": randint(20, 400), "initial_K": randint(20, 400),
} }
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42) clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
search = RandomizedSearchCV( search = RandomizedSearchCV(
@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
n_jobs=16, n_jobs=16,
cv=cv, cv=cv,
random_state=42, random_state=42,
verbose=10, verbose=3,
scoring=metrics, scoring=metrics,
refit="f1", refit="f1",
) )
@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
}, },
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": ["No RTS", "RTS"],
} }
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator) preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
preds_file = results_dir / "predicted_probabilities.parquet" preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}") print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file) preds.to_parquet(preds_file)

View file

@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]:
return sorted(result_files, reverse=True) # Most recent first return sorted(result_files, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame: def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns. """Load results file and prepare binned columns.
Args: Args:
file_path: Path to the results parquet file. file_path: Path to the results parquet file.
settings: Dictionary of search settings.
k_bin_width: Width of bins for initial_K parameter. k_bin_width: Width of bins for initial_K parameter.
Returns: Returns:
@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF
results = pd.read_parquet(file_path) results = pd.read_parquet(file_path)
# Automatically determine bin width for initial_K based on data range # Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min() k_min = settings["param_grid"]["initial_K"]["low"]
k_max = results["initial_K"].max() k_max = settings["param_grid"]["initial_K"]["high"]
# Use configurable bin width, adapted to actual data range # Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range # Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range # Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min()) eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"])
eps_cl_max = np.log10(results["eps_cl"].max()) eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"])
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10) eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1))
eps_e_min = np.log10(results["eps_e"].min()) eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"])
eps_e_max = np.log10(results["eps_e"].max()) eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"])
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10) eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1))
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
return common_feature_array return common_feature_array
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
    """Wrap the predicted probabilities in an xdggs-decoded DataArray.

    Args:
        preds: GeoDataFrame providing the "cell_id" and "predicted_proba" columns.
        settings: Search settings; the "grid" and "level" entries are read.

    Returns:
        One-dimensional DataArray over "cell_ids", passed through xdggs.decode.

    """
    grid_name = settings["grid"]
    cell_attrs = {
        # The settings call the hexagonal grid "hex"; the attribute written here is "h3".
        "grid_name": grid_name if grid_name != "hex" else "h3",
        "level": settings["level"],
    }
    if grid_name == "healpix":
        cell_attrs["indexing_scheme"] = "nested"

    values = preds["predicted_proba"].to_numpy()
    cell_ids = preds["cell_id"].to_numpy()
    probs = xr.DataArray(data=values, dims=["cell_ids"], coords={"cell_ids": cell_ids})
    # The grid description is attached to the coordinate before decoding.
    probs.cell_ids.attrs = cell_attrs
    return xdggs.decode(probs)
def _plot_prediction_map(preds: gpd.GeoDataFrame): def _plot_prediction_map(preds: gpd.GeoDataFrame):
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron") return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
# return probs.dggs.explore()
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
    """Build a bar chart with the number of cells per predicted class.

    Args:
        preds: GeoDataFrame with a "predicted_class" column.

    Returns:
        Altair bar chart with one bar per class; the tooltip adds the count
        and the percentage share of each class.

    """
    labels = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
    tally = labels["predicted_class"].value_counts().reset_index()
    # Rename positionally so the result does not depend on the column names
    # that reset_index() happens to produce.
    tally.columns = ["class", "count"]
    total = tally["count"].sum()
    tally["percentage"] = (tally["count"] / total * 100).round(2)

    x_axis = alt.X("class:N", title="Predicted Class")
    y_axis = alt.Y("count:Q", title="Number of Cells")
    bar_color = alt.Color("class:N", title="Class", legend=None)
    hover = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("count:Q", title="Count"),
        alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
    ]

    bars = alt.Chart(tally).mark_bar().encode(x=x_axis, y=y_axis, color=bar_color, tooltip=hover)
    return bars.properties(width=400, height=300, title="Predicted Class Distribution")
def _plot_k_binned( def _plot_k_binned(
@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray):
return chart return chart
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
    """Build an interactive scatter plot of one test metric against another.

    Args:
        results: DataFrame with the "mean_test_<metric>" columns and the
            "initial_K", "eps_cl" and "eps_e" parameter columns.
        x_metric: Metric shown on the x-axis (e.g., 'precision', 'accuracy').
        y_metric: Metric shown on the y-axis (e.g., 'recall', 'jaccard').
        color_param: Parameter used for coloring ('initial_K', 'eps_cl', or 'eps_e').

    Returns:
        Altair chart with one circle per row of ``results``.

    """
    # The epsilon parameters get a logarithmic color scale; anything else is linear.
    log_scaled = color_param in ["eps_cl", "eps_e"]
    color_scale = alt.Scale(type="log", scheme="viridis") if log_scaled else alt.Scale(scheme="viridis")

    x_field = f"mean_test_{x_metric}:Q"
    y_field = f"mean_test_{y_metric}:Q"
    x_title = x_metric.capitalize()
    y_title = y_metric.capitalize()

    hover = [
        alt.Tooltip(x_field, format=".4f", title=x_title),
        alt.Tooltip(y_field, format=".4f", title=y_title),
        alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
        alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
        alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
        alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
    ]

    points = alt.Chart(results).mark_circle(size=60, opacity=0.7)
    encoded = points.encode(
        # zero=False keeps the axes tight around the observed metric values.
        x=alt.X(x_field, title=x_title, scale=alt.Scale(zero=False)),
        y=alt.Y(y_field, title=y_title, scale=alt.Scale(zero=False)),
        color=alt.Color(f"{color_param}:Q", scale=color_scale, title=color_param),
        tooltip=hover,
    )
    return encoded.properties(width=400, height=400).interactive()
def main(): def main():
"""Run Streamlit dashboard application.""" """Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -769,7 +836,8 @@ def main():
# Load and prepare data with default bin width (will be reloaded with custom width later) # Load and prepare data with default bin width (will be reloaded with custom width later)
with st.spinner("Loading results..."): with st.spinner("Loading results..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40) settings = toml.load(results_dir / "search_settings.toml")["settings"]
results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40)
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc") model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
n_features = model_state.sizes["feature"] n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features model_state["feature_weights"] *= n_features
@ -777,40 +845,62 @@ def main():
era5_feature_array = extract_era5_features(model_state) era5_feature_array = extract_era5_features(model_state)
common_feature_array = extract_common_features(model_state) common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
settings = toml.load(results_dir / "search_settings.toml")["settings"]
probs = _extract_probs_as_xdggs(predictions, settings)
st.sidebar.success(f"Loaded {len(results)} results") st.sidebar.success(f"Loaded {len(results)} results")
# Metric selection
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.sidebar.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Display some basic statistics first (lightweight) # Display some basic statistics first (lightweight)
st.header("Dataset Overview") st.header("Parameter-Search Overview")
col1, col2, col3 = st.columns(3)
# Show total runs and best model info
col1, col2 = st.columns(2)
with col1: with col1:
st.metric("Total Runs", len(results)) st.metric("Total Runs", len(results))
with col2:
best_score = results[f"mean_test_{selected_metric}"].max()
st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}")
with col3:
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
best_k = results.loc[best_idx, "initial_K"]
st.metric("Best K", f"{best_k:.0f}")
# Show best parameters with col2:
with st.expander("Best Parameters"): # Best model based on F1 score
best_idx = results[f"mean_test_{selected_metric}"].idxmax() best_f1_idx = results["mean_test_f1"].idxmax()
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
st.dataframe(best_params.to_frame().T, width="content")
# Show best parameters for the best model
with st.expander("Best Model Parameters (by F1)", expanded=True):
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
with col2:
st.metric("eps_cl", f"{best_params['eps_cl']:.2e}")
with col3:
st.metric("eps_e", f"{best_params['eps_e']:.2e}")
with col4:
st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}")
with col5:
st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}")
# Show all metrics
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
cols = st.columns(len(available_metrics))
for idx, metric in enumerate(available_metrics):
with cols[idx]:
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
st.metric(f"{metric.capitalize()}", f"{best_score:.4f}")
# Create tabs for different visualizations # Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"]) tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
with tab1: with tab1:
# Metric selection - only used in this tab
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Show best parameters
with st.expander("Best Parameters"):
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
st.dataframe(best_params.to_frame().T, width="content")
# Main plots # Main plots
st.header(f"Visualization for {selected_metric.capitalize()}") st.header(f"Visualization for {selected_metric.capitalize()}")
@ -847,7 +937,9 @@ def main():
# Reload data if bin width changed from default # Reload data if bin width changed from default
if k_bin_width != 40: if k_bin_width != 40:
with st.spinner("Re-binning data..."): with st.spinner("Re-binning data..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width) results = load_and_prepare_results(
results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width
)
# K-binned plots # K-binned plots
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
@ -885,6 +977,30 @@ def main():
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}") chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
st.altair_chart(chart4, use_container_width=True) st.altair_chart(chart4, use_container_width=True)
# Metric comparison plots
st.header("Metric Comparisons")
# Color parameter selection
color_param = st.selectbox(
"Select Color Parameter",
options=["initial_K", "eps_cl", "eps_e"],
help="Choose which parameter to use for coloring the scatter plots",
)
col1, col2 = st.columns(2)
with col1:
st.subheader("Recall vs Precision")
with st.spinner("Generating Recall vs Precision plot..."):
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
st.altair_chart(recall_precision_chart, use_container_width=True)
with col2:
st.subheader("Accuracy vs Jaccard")
with st.spinner("Generating Accuracy vs Jaccard plot..."):
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
# Optional: Raw data table # Optional: Raw data table
with st.expander("View Raw Results Data"): with st.expander("View Raw Results Data"):
st.dataframe(results, width="stretch") st.dataframe(results, width="stretch")
@ -993,164 +1109,226 @@ def main():
# Embedding features analysis (if present) # Embedding features analysis (if present)
if embedding_feature_array is not None: if embedding_feature_array is not None:
st.subheader(":artificial_satellite: Embedding Feature Analysis") with st.container(border=True):
st.markdown( st.header(":artificial_satellite: Embedding Feature Analysis")
""" st.markdown(
Analysis of embedding features showing which aggregations, bands, and years """
are most important for the model predictions. Analysis of embedding features showing which aggregations, bands, and years
""" are most important for the model predictions.
) """
)
# Summary bar charts # Summary bar charts
st.markdown("### Importance by Dimension") st.markdown("### Importance by Dimension")
with st.spinner("Generating dimension summaries..."): with st.spinner("Generating dimension summaries..."):
chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array) chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array)
col1, col2, col3 = st.columns(3) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.altair_chart(chart_agg, use_container_width=True) st.altair_chart(chart_agg, use_container_width=True)
with col2: with col2:
st.altair_chart(chart_band, use_container_width=True) st.altair_chart(chart_band, use_container_width=True)
with col3: with col3:
st.altair_chart(chart_year, use_container_width=True) st.altair_chart(chart_year, use_container_width=True)
# Detailed heatmap # Detailed heatmap
st.markdown("### Detailed Heatmap by Aggregation") st.markdown("### Detailed Heatmap by Aggregation")
st.markdown("Shows the weight of each band-year combination for each aggregation type.") st.markdown("Shows the weight of each band-year combination for each aggregation type.")
with st.spinner("Generating heatmap..."): with st.spinner("Generating heatmap..."):
heatmap_chart = _plot_embedding_heatmap(embedding_feature_array) heatmap_chart = _plot_embedding_heatmap(embedding_feature_array)
st.altair_chart(heatmap_chart, use_container_width=True) st.altair_chart(heatmap_chart, use_container_width=True)
# Statistics # Statistics
with st.expander("Embedding Feature Statistics"): with st.expander("Embedding Feature Statistics"):
st.write("**Overall Statistics:**") st.write("**Overall Statistics:**")
n_emb_features = embedding_feature_array.size n_emb_features = embedding_feature_array.size
mean_weight = float(embedding_feature_array.mean().values) mean_weight = float(embedding_feature_array.mean().values)
max_weight = float(embedding_feature_array.max().values) max_weight = float(embedding_feature_array.max().values)
col1, col2, col3 = st.columns(3) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.metric("Total Embedding Features", n_emb_features) st.metric("Total Embedding Features", n_emb_features)
with col2: with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}") st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3: with col3:
st.metric("Max Weight", f"{max_weight:.4f}") st.metric("Max Weight", f"{max_weight:.4f}")
# Show top embedding features # Show top embedding features
st.write("**Top 10 Embedding Features:**") st.write("**Top 10 Embedding Features:**")
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
st.dataframe(top_emb, width="stretch") st.dataframe(top_emb, width="stretch")
else: else:
st.info("No embedding features found in this model.") st.info("No embedding features found in this model.")
# ERA5 features analysis (if present) # ERA5 features analysis (if present)
if era5_feature_array is not None: if era5_feature_array is not None:
st.subheader(":partly_sunny: ERA5 Feature Analysis") with st.container(border=True):
st.markdown( st.header(":partly_sunny: ERA5 Feature Analysis")
""" st.markdown(
Analysis of ERA5 climate features showing which variables and time periods """
are most important for the model predictions. Analysis of ERA5 climate features showing which variables and time periods
""" are most important for the model predictions.
) """
)
# Summary bar charts # Summary bar charts
st.markdown("### Importance by Dimension") st.markdown("### Importance by Dimension")
with st.spinner("Generating ERA5 dimension summaries..."): with st.spinner("Generating ERA5 dimension summaries..."):
chart_variable, chart_time = _plot_era5_summary(era5_feature_array) chart_variable, chart_time = _plot_era5_summary(era5_feature_array)
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
with col1: with col1:
st.altair_chart(chart_variable, use_container_width=True) st.altair_chart(chart_variable, use_container_width=True)
with col2: with col2:
st.altair_chart(chart_time, use_container_width=True) st.altair_chart(chart_time, use_container_width=True)
# Detailed heatmap # Detailed heatmap
st.markdown("### Detailed Heatmap") st.markdown("### Detailed Heatmap")
st.markdown("Shows the weight of each variable-time combination.") st.markdown("Shows the weight of each variable-time combination.")
with st.spinner("Generating ERA5 heatmap..."): with st.spinner("Generating ERA5 heatmap..."):
era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array) era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array)
st.altair_chart(era5_heatmap_chart, use_container_width=True) st.altair_chart(era5_heatmap_chart, use_container_width=True)
# Statistics # Statistics
with st.expander("ERA5 Feature Statistics"): with st.expander("ERA5 Feature Statistics"):
st.write("**Overall Statistics:**") st.write("**Overall Statistics:**")
n_era5_features = era5_feature_array.size n_era5_features = era5_feature_array.size
mean_weight = float(era5_feature_array.mean().values) mean_weight = float(era5_feature_array.mean().values)
max_weight = float(era5_feature_array.max().values) max_weight = float(era5_feature_array.max().values)
col1, col2, col3 = st.columns(3) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.metric("Total ERA5 Features", n_era5_features) st.metric("Total ERA5 Features", n_era5_features)
with col2: with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}") st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3: with col3:
st.metric("Max Weight", f"{max_weight:.4f}") st.metric("Max Weight", f"{max_weight:.4f}")
# Show top ERA5 features # Show top ERA5 features
st.write("**Top 10 ERA5 Features:**") st.write("**Top 10 ERA5 Features:**")
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
st.dataframe(top_era5, width="stretch") st.dataframe(top_era5, width="stretch")
else: else:
st.info("No ERA5 features found in this model.") st.info("No ERA5 features found in this model.")
# Common features analysis (if present) # Common features analysis (if present)
if common_feature_array is not None: if common_feature_array is not None:
st.subheader(":world_map: Common Feature Analysis") with st.container(border=True):
st.markdown( st.header(":world_map: Common Feature Analysis")
""" st.markdown(
Analysis of common features including cell area, water area, land area, land ratio, """
longitude, and latitude. These features provide spatial and geographic context. Analysis of common features including cell area, water area, land area, land ratio,
""" longitude, and latitude. These features provide spatial and geographic context.
) """
)
# Bar chart showing all common feature weights # Bar chart showing all common feature weights
with st.spinner("Generating common features chart..."): with st.spinner("Generating common features chart..."):
common_chart = _plot_common_features(common_feature_array) common_chart = _plot_common_features(common_feature_array)
st.altair_chart(common_chart, use_container_width=True) st.altair_chart(common_chart, use_container_width=True)
# Statistics # Statistics
with st.expander("Common Feature Statistics"): with st.expander("Common Feature Statistics"):
st.write("**Overall Statistics:**") st.write("**Overall Statistics:**")
n_common_features = common_feature_array.size n_common_features = common_feature_array.size
mean_weight = float(common_feature_array.mean().values) mean_weight = float(common_feature_array.mean().values)
max_weight = float(common_feature_array.max().values) max_weight = float(common_feature_array.max().values)
min_weight = float(common_feature_array.min().values) min_weight = float(common_feature_array.min().values)
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
with col1: with col1:
st.metric("Total Common Features", n_common_features) st.metric("Total Common Features", n_common_features)
with col2: with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}") st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3: with col3:
st.metric("Max Weight", f"{max_weight:.4f}") st.metric("Max Weight", f"{max_weight:.4f}")
with col4: with col4:
st.metric("Min Weight", f"{min_weight:.4f}") st.metric("Min Weight", f"{min_weight:.4f}")
# Show all common features sorted by importance # Show all common features sorted by importance
st.write("**All Common Features (by absolute weight):**") st.write("**All Common Features (by absolute weight):**")
common_df = common_feature_array.to_dataframe(name="weight").reset_index() common_df = common_feature_array.to_dataframe(name="weight").reset_index()
common_df["abs_weight"] = common_df["weight"].abs() common_df["abs_weight"] = common_df["weight"].abs()
common_df = common_df.sort_values("abs_weight", ascending=False) common_df = common_df.sort_values("abs_weight", ascending=False)
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
st.markdown( st.markdown(
""" """
**Interpretation:** **Interpretation:**
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns - **cell_area, water_area, land_area**: Spatial extent features that may indicate
- **land_ratio**: Proportion of land vs water in each cell size-related patterns
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns - **land_ratio**: Proportion of land vs water in each cell
- Positive weights indicate features that increase the probability of the positive class - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- Negative weights indicate features that decrease the probability of the positive class - Positive weights indicate features that increase the probability of the positive class
""" - Negative weights indicate features that decrease the probability of the positive class
) """
)
else: else:
st.info("No common features found in this model.") st.info("No common features found in this model.")
with tab3: with tab3:
# Map visualization # Inference analysis
st.header("Predictions Map") st.header("Inference Analysis")
st.markdown("Map showing predicted classes from the best estimator") st.markdown("Comprehensive analysis of model predictions on the evaluation dataset")
with st.spinner("Generating map..."):
# Summary statistics
st.subheader("Prediction Statistics")
col1, col2 = st.columns(2)
with col1:
total_cells = len(predictions)
st.metric("Total Cells", f"{total_cells:,}")
with col2:
n_classes = predictions["predicted_class"].nunique()
st.metric("Number of Classes", n_classes)
# Class distribution visualization
st.subheader("Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_prediction_class_distribution(predictions)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between predicted classes
- Class imbalance may indicate regional patterns or model bias
- Each bar represents the count of cells predicted for that class
"""
)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
with st.spinner("Generating interactive map..."):
chart_map = _plot_prediction_map(predictions) chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[]) st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Prediction Statistics"):
st.write("**Class Distribution:**")
class_counts = predictions["predicted_class"].value_counts().sort_index()
# Create columns for better layout
n_cols = min(5, len(class_counts))
cols = st.columns(n_cols)
for idx, (class_label, count) in enumerate(class_counts.items()):
percentage = count / len(predictions) * 100
with cols[idx % n_cols]:
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
# Show detailed table
st.write("**Detailed Class Breakdown:**")
class_df = pd.DataFrame(
{
"Class": class_counts.index,
"Count": class_counts.to_numpy(),
"Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2),
}
)
st.dataframe(class_df, width="stretch", hide_index=True)
if __name__ == "__main__": if __name__ == "__main__":
main() main()