Switch to predict_classes
This commit is contained in:
parent
d4a747d800
commit
b2cfddfead
3 changed files with 373 additions and 192 deletions
|
|
@ -18,13 +18,14 @@ pretty.install()
|
|||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
|
||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
|
||||
"""Get predicted probabilities for each cell.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
clf (ESPAClassifier): The trained classifier to use for predictions.
|
||||
classes (list): List of class names.
|
||||
|
||||
Returns:
|
||||
list: A list of predicted probabilities for each cell.
|
||||
|
|
@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
|
|||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||
X_batch = X_batch.to_numpy(dtype="float32")
|
||||
X_batch = torch.asarray(X_batch, device=0)
|
||||
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
|
||||
batch_preds = clf.predict(X_batch).cpu().numpy()
|
||||
batch_preds = gpd.GeoDataFrame(
|
||||
{
|
||||
"cell_id": cell_ids,
|
||||
"predicted_proba": batch_preds,
|
||||
"predicted_class": [classes[i] for i in batch_preds],
|
||||
"geometry": cell_geoms,
|
||||
},
|
||||
crs="epsg:3413",
|
||||
|
|
|
|||
|
|
@ -29,13 +29,14 @@ pretty.install()
|
|||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
|
|
@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
"initial_K": randint(20, 400),
|
||||
}
|
||||
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42)
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||
search = RandomizedSearchCV(
|
||||
|
|
@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
n_jobs=16,
|
||||
cv=cv,
|
||||
random_state=42,
|
||||
verbose=10,
|
||||
verbose=3,
|
||||
scoring=metrics,
|
||||
refit="f1",
|
||||
)
|
||||
|
|
@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
},
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
"classes": ["No RTS", "RTS"],
|
||||
}
|
||||
settings_file = results_dir / "search_settings.toml"
|
||||
print(f"Storing search settings to {settings_file}")
|
||||
|
|
@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
|
||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||
print(f"Storing predicted probabilities to {preds_file}")
|
||||
preds.to_parquet(preds_file)
|
||||
|
|
|
|||
|
|
@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]:
|
|||
return sorted(result_files, reverse=True) # Most recent first
|
||||
|
||||
|
||||
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
||||
def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
|
||||
"""Load results file and prepare binned columns.
|
||||
|
||||
Args:
|
||||
file_path: Path to the results parquet file.
|
||||
settings: Dictionary of search settings.
|
||||
k_bin_width: Width of bins for initial_K parameter.
|
||||
|
||||
Returns:
|
||||
|
|
@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF
|
|||
results = pd.read_parquet(file_path)
|
||||
|
||||
# Automatically determine bin width for initial_K based on data range
|
||||
k_min = results["initial_K"].min()
|
||||
k_max = results["initial_K"].max()
|
||||
k_min = settings["param_grid"]["initial_K"]["low"]
|
||||
k_max = settings["param_grid"]["initial_K"]["high"]
|
||||
# Use configurable bin width, adapted to actual data range
|
||||
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||
|
||||
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||
# Use 10 bins spanning the actual data range
|
||||
eps_cl_min = np.log10(results["eps_cl"].min())
|
||||
eps_cl_max = np.log10(results["eps_cl"].max())
|
||||
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
||||
eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"])
|
||||
eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"])
|
||||
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1))
|
||||
|
||||
eps_e_min = np.log10(results["eps_e"].min())
|
||||
eps_e_max = np.log10(results["eps_e"].max())
|
||||
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
||||
eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"])
|
||||
eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"])
|
||||
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1))
|
||||
|
||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||
|
|
@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
|||
return common_feature_array
|
||||
|
||||
|
||||
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
|
||||
grid = settings["grid"]
|
||||
level = settings["level"]
|
||||
|
||||
probs = xr.DataArray(
|
||||
data=preds["predicted_proba"].to_numpy(),
|
||||
dims=["cell_ids"],
|
||||
coords={"cell_ids": preds["cell_id"].to_numpy()},
|
||||
)
|
||||
gridinfo = {
|
||||
"grid_name": "h3" if grid == "hex" else grid,
|
||||
"level": level,
|
||||
}
|
||||
if grid == "healpix":
|
||||
gridinfo["indexing_scheme"] = "nested"
|
||||
probs.cell_ids.attrs = gridinfo
|
||||
probs = xdggs.decode(probs)
|
||||
return probs
|
||||
|
||||
|
||||
def _plot_prediction_map(preds: gpd.GeoDataFrame):
|
||||
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||
# return probs.dggs.explore()
|
||||
return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||
|
||||
|
||||
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
|
||||
"""Create a bar chart showing the count of each predicted class.
|
||||
|
||||
Args:
|
||||
preds: GeoDataFrame with predicted classes.
|
||||
|
||||
Returns:
|
||||
Altair chart showing class distribution.
|
||||
|
||||
"""
|
||||
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
|
||||
counts = df["predicted_class"].value_counts().reset_index()
|
||||
counts.columns = ["class", "count"]
|
||||
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
|
||||
|
||||
chart = (
|
||||
alt.Chart(counts)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
x=alt.X("class:N", title="Predicted Class"),
|
||||
y=alt.Y("count:Q", title="Number of Cells"),
|
||||
color=alt.Color("class:N", title="Class", legend=None),
|
||||
tooltip=[
|
||||
alt.Tooltip("class:N", title="Class"),
|
||||
alt.Tooltip("count:Q", title="Count"),
|
||||
alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
width=400,
|
||||
height=300,
|
||||
title="Predicted Class Distribution",
|
||||
)
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def _plot_k_binned(
|
||||
|
|
@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray):
|
|||
return chart
|
||||
|
||||
|
||||
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
|
||||
"""Create a scatter plot comparing two metrics with parameter-based coloring.
|
||||
|
||||
Args:
|
||||
results: DataFrame containing the results with metric columns.
|
||||
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
|
||||
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
|
||||
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
|
||||
|
||||
Returns:
|
||||
Altair chart showing the metric comparison.
|
||||
|
||||
"""
|
||||
# Determine color scale based on parameter
|
||||
if color_param in ["eps_cl", "eps_e"]:
|
||||
color_scale = alt.Scale(type="log", scheme="viridis")
|
||||
else:
|
||||
color_scale = alt.Scale(scheme="viridis")
|
||||
|
||||
chart = (
|
||||
alt.Chart(results)
|
||||
.mark_circle(size=60, opacity=0.7)
|
||||
.encode(
|
||||
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)),
|
||||
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)),
|
||||
color=alt.Color(
|
||||
f"{color_param}:Q",
|
||||
scale=color_scale,
|
||||
title=color_param,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()),
|
||||
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()),
|
||||
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
|
||||
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
|
||||
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
|
||||
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
width=400,
|
||||
height=400,
|
||||
)
|
||||
.interactive()
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def main():
|
||||
"""Run Streamlit dashboard application."""
|
||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||
|
|
@ -769,7 +836,8 @@ def main():
|
|||
|
||||
# Load and prepare data with default bin width (will be reloaded with custom width later)
|
||||
with st.spinner("Loading results..."):
|
||||
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40)
|
||||
settings = toml.load(results_dir / "search_settings.toml")["settings"]
|
||||
results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40)
|
||||
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
|
||||
n_features = model_state.sizes["feature"]
|
||||
model_state["feature_weights"] *= n_features
|
||||
|
|
@ -777,40 +845,62 @@ def main():
|
|||
era5_feature_array = extract_era5_features(model_state)
|
||||
common_feature_array = extract_common_features(model_state)
|
||||
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||
settings = toml.load(results_dir / "search_settings.toml")["settings"]
|
||||
probs = _extract_probs_as_xdggs(predictions, settings)
|
||||
|
||||
st.sidebar.success(f"Loaded {len(results)} results")
|
||||
|
||||
# Metric selection
|
||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
||||
selected_metric = st.sidebar.selectbox(
|
||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||
)
|
||||
|
||||
# Display some basic statistics first (lightweight)
|
||||
st.header("Dataset Overview")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
st.header("Parameter-Search Overview")
|
||||
|
||||
# Show total runs and best model info
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Total Runs", len(results))
|
||||
with col2:
|
||||
best_score = results[f"mean_test_{selected_metric}"].max()
|
||||
st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}")
|
||||
with col3:
|
||||
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
||||
best_k = results.loc[best_idx, "initial_K"]
|
||||
st.metric("Best K", f"{best_k:.0f}")
|
||||
|
||||
# Show best parameters
|
||||
with st.expander("Best Parameters"):
|
||||
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
||||
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
||||
st.dataframe(best_params.to_frame().T, width="content")
|
||||
with col2:
|
||||
# Best model based on F1 score
|
||||
best_f1_idx = results["mean_test_f1"].idxmax()
|
||||
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
|
||||
|
||||
# Show best parameters for the best model
|
||||
with st.expander("Best Model Parameters (by F1)", expanded=True):
|
||||
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
with col1:
|
||||
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
|
||||
with col2:
|
||||
st.metric("eps_cl", f"{best_params['eps_cl']:.2e}")
|
||||
with col3:
|
||||
st.metric("eps_e", f"{best_params['eps_e']:.2e}")
|
||||
with col4:
|
||||
st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}")
|
||||
with col5:
|
||||
st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}")
|
||||
|
||||
# Show all metrics
|
||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
||||
cols = st.columns(len(available_metrics))
|
||||
|
||||
for idx, metric in enumerate(available_metrics):
|
||||
with cols[idx]:
|
||||
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
|
||||
st.metric(f"{metric.capitalize()}", f"{best_score:.4f}")
|
||||
|
||||
# Create tabs for different visualizations
|
||||
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
|
||||
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
|
||||
|
||||
with tab1:
|
||||
# Metric selection - only used in this tab
|
||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
||||
selected_metric = st.selectbox(
|
||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||
)
|
||||
|
||||
# Show best parameters
|
||||
with st.expander("Best Parameters"):
|
||||
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
||||
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
||||
st.dataframe(best_params.to_frame().T, width="content")
|
||||
|
||||
# Main plots
|
||||
st.header(f"Visualization for {selected_metric.capitalize()}")
|
||||
|
||||
|
|
@ -847,7 +937,9 @@ def main():
|
|||
# Reload data if bin width changed from default
|
||||
if k_bin_width != 40:
|
||||
with st.spinner("Re-binning data..."):
|
||||
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width)
|
||||
results = load_and_prepare_results(
|
||||
results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width
|
||||
)
|
||||
|
||||
# K-binned plots
|
||||
col1, col2 = st.columns(2)
|
||||
|
|
@ -885,6 +977,30 @@ def main():
|
|||
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
|
||||
st.altair_chart(chart4, use_container_width=True)
|
||||
|
||||
# Metric comparison plots
|
||||
st.header("Metric Comparisons")
|
||||
|
||||
# Color parameter selection
|
||||
color_param = st.selectbox(
|
||||
"Select Color Parameter",
|
||||
options=["initial_K", "eps_cl", "eps_e"],
|
||||
help="Choose which parameter to use for coloring the scatter plots",
|
||||
)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.subheader("Recall vs Precision")
|
||||
with st.spinner("Generating Recall vs Precision plot..."):
|
||||
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
|
||||
st.altair_chart(recall_precision_chart, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
st.subheader("Accuracy vs Jaccard")
|
||||
with st.spinner("Generating Accuracy vs Jaccard plot..."):
|
||||
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
|
||||
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
|
||||
|
||||
# Optional: Raw data table
|
||||
with st.expander("View Raw Results Data"):
|
||||
st.dataframe(results, width="stretch")
|
||||
|
|
@ -993,164 +1109,226 @@ def main():
|
|||
|
||||
# Embedding features analysis (if present)
|
||||
if embedding_feature_array is not None:
|
||||
st.subheader(":artificial_satellite: Embedding Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of embedding features showing which aggregations, bands, and years
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
with st.container(border=True):
|
||||
st.header(":artificial_satellite: Embedding Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of embedding features showing which aggregations, bands, and years
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating dimension summaries..."):
|
||||
chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_agg, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_band, use_container_width=True)
|
||||
with col3:
|
||||
st.altair_chart(chart_year, use_container_width=True)
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating dimension summaries..."):
|
||||
chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_agg, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_band, use_container_width=True)
|
||||
with col3:
|
||||
st.altair_chart(chart_year, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap by Aggregation")
|
||||
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||
with st.spinner("Generating heatmap..."):
|
||||
heatmap_chart = _plot_embedding_heatmap(embedding_feature_array)
|
||||
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap by Aggregation")
|
||||
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||
with st.spinner("Generating heatmap..."):
|
||||
heatmap_chart = _plot_embedding_heatmap(embedding_feature_array)
|
||||
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Embedding Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_emb_features = embedding_feature_array.size
|
||||
mean_weight = float(embedding_feature_array.mean().values)
|
||||
max_weight = float(embedding_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total Embedding Features", n_emb_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
# Statistics
|
||||
with st.expander("Embedding Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_emb_features = embedding_feature_array.size
|
||||
mean_weight = float(embedding_feature_array.mean().values)
|
||||
max_weight = float(embedding_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total Embedding Features", n_emb_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top embedding features
|
||||
st.write("**Top 10 Embedding Features:**")
|
||||
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
|
||||
st.dataframe(top_emb, width="stretch")
|
||||
# Show top embedding features
|
||||
st.write("**Top 10 Embedding Features:**")
|
||||
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
|
||||
st.dataframe(top_emb, width="stretch")
|
||||
else:
|
||||
st.info("No embedding features found in this model.")
|
||||
|
||||
# ERA5 features analysis (if present)
|
||||
if era5_feature_array is not None:
|
||||
st.subheader(":partly_sunny: ERA5 Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of ERA5 climate features showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
with st.container(border=True):
|
||||
st.header(":partly_sunny: ERA5 Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of ERA5 climate features showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||
chart_variable, chart_time = _plot_era5_summary(era5_feature_array)
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_time, use_container_width=True)
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||
chart_variable, chart_time = _plot_era5_summary(era5_feature_array)
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_time, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
st.markdown("Shows the weight of each variable-time combination.")
|
||||
with st.spinner("Generating ERA5 heatmap..."):
|
||||
era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
st.markdown("Shows the weight of each variable-time combination.")
|
||||
with st.spinner("Generating ERA5 heatmap..."):
|
||||
era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("ERA5 Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_era5_features = era5_feature_array.size
|
||||
mean_weight = float(era5_feature_array.mean().values)
|
||||
max_weight = float(era5_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total ERA5 Features", n_era5_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
# Statistics
|
||||
with st.expander("ERA5 Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_era5_features = era5_feature_array.size
|
||||
mean_weight = float(era5_feature_array.mean().values)
|
||||
max_weight = float(era5_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total ERA5 Features", n_era5_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top ERA5 features
|
||||
st.write("**Top 10 ERA5 Features:**")
|
||||
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||
st.dataframe(top_era5, width="stretch")
|
||||
# Show top ERA5 features
|
||||
st.write("**Top 10 ERA5 Features:**")
|
||||
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||
st.dataframe(top_era5, width="stretch")
|
||||
else:
|
||||
st.info("No ERA5 features found in this model.")
|
||||
|
||||
# Common features analysis (if present)
|
||||
if common_feature_array is not None:
|
||||
st.subheader(":world_map: Common Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of common features including cell area, water area, land area, land ratio,
|
||||
longitude, and latitude. These features provide spatial and geographic context.
|
||||
"""
|
||||
)
|
||||
with st.container(border=True):
|
||||
st.header(":world_map: Common Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of common features including cell area, water area, land area, land ratio,
|
||||
longitude, and latitude. These features provide spatial and geographic context.
|
||||
"""
|
||||
)
|
||||
|
||||
# Bar chart showing all common feature weights
|
||||
with st.spinner("Generating common features chart..."):
|
||||
common_chart = _plot_common_features(common_feature_array)
|
||||
st.altair_chart(common_chart, use_container_width=True)
|
||||
# Bar chart showing all common feature weights
|
||||
with st.spinner("Generating common features chart..."):
|
||||
common_chart = _plot_common_features(common_feature_array)
|
||||
st.altair_chart(common_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Common Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_common_features = common_feature_array.size
|
||||
mean_weight = float(common_feature_array.mean().values)
|
||||
max_weight = float(common_feature_array.max().values)
|
||||
min_weight = float(common_feature_array.min().values)
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Common Features", n_common_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
with col4:
|
||||
st.metric("Min Weight", f"{min_weight:.4f}")
|
||||
# Statistics
|
||||
with st.expander("Common Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_common_features = common_feature_array.size
|
||||
mean_weight = float(common_feature_array.mean().values)
|
||||
max_weight = float(common_feature_array.max().values)
|
||||
min_weight = float(common_feature_array.min().values)
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Common Features", n_common_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
with col4:
|
||||
st.metric("Min Weight", f"{min_weight:.4f}")
|
||||
|
||||
# Show all common features sorted by importance
|
||||
st.write("**All Common Features (by absolute weight):**")
|
||||
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
|
||||
common_df["abs_weight"] = common_df["weight"].abs()
|
||||
common_df = common_df.sort_values("abs_weight", ascending=False)
|
||||
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
|
||||
# Show all common features sorted by importance
|
||||
st.write("**All Common Features (by absolute weight):**")
|
||||
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
|
||||
common_df["abs_weight"] = common_df["weight"].abs()
|
||||
common_df = common_df.sort_values("abs_weight", ascending=False)
|
||||
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
|
||||
- **land_ratio**: Proportion of land vs water in each cell
|
||||
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||
- Positive weights indicate features that increase the probability of the positive class
|
||||
- Negative weights indicate features that decrease the probability of the positive class
|
||||
"""
|
||||
)
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
|
||||
size-related patterns
|
||||
- **land_ratio**: Proportion of land vs water in each cell
|
||||
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||
- Positive weights indicate features that increase the probability of the positive class
|
||||
- Negative weights indicate features that decrease the probability of the positive class
|
||||
"""
|
||||
)
|
||||
else:
|
||||
st.info("No common features found in this model.")
|
||||
|
||||
with tab3:
|
||||
# Map visualization
|
||||
st.header("Predictions Map")
|
||||
st.markdown("Map showing predicted classes from the best estimator")
|
||||
with st.spinner("Generating map..."):
|
||||
# Inference analysis
|
||||
st.header("Inference Analysis")
|
||||
st.markdown("Comprehensive analysis of model predictions on the evaluation dataset")
|
||||
|
||||
# Summary statistics
|
||||
st.subheader("Prediction Statistics")
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
total_cells = len(predictions)
|
||||
st.metric("Total Cells", f"{total_cells:,}")
|
||||
|
||||
with col2:
|
||||
n_classes = predictions["predicted_class"].nunique()
|
||||
st.metric("Number of Classes", n_classes)
|
||||
|
||||
# Class distribution visualization
|
||||
st.subheader("Class Distribution")
|
||||
|
||||
with st.spinner("Generating class distribution..."):
|
||||
class_dist_chart = _plot_prediction_class_distribution(predictions)
|
||||
st.altair_chart(class_dist_chart, use_container_width=True)
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- Shows the balance between predicted classes
|
||||
- Class imbalance may indicate regional patterns or model bias
|
||||
- Each bar represents the count of cells predicted for that class
|
||||
"""
|
||||
)
|
||||
|
||||
# Interactive map
|
||||
st.subheader("Interactive Prediction Map")
|
||||
st.markdown("Explore predictions spatially with the interactive map below")
|
||||
|
||||
with st.spinner("Generating interactive map..."):
|
||||
chart_map = _plot_prediction_map(predictions)
|
||||
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
||||
|
||||
# Additional statistics in expander
|
||||
with st.expander("Detailed Prediction Statistics"):
|
||||
st.write("**Class Distribution:**")
|
||||
class_counts = predictions["predicted_class"].value_counts().sort_index()
|
||||
|
||||
# Create columns for better layout
|
||||
n_cols = min(5, len(class_counts))
|
||||
cols = st.columns(n_cols)
|
||||
|
||||
for idx, (class_label, count) in enumerate(class_counts.items()):
|
||||
percentage = count / len(predictions) * 100
|
||||
with cols[idx % n_cols]:
|
||||
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
|
||||
|
||||
# Show detailed table
|
||||
st.write("**Detailed Class Breakdown:**")
|
||||
class_df = pd.DataFrame(
|
||||
{
|
||||
"Class": class_counts.index,
|
||||
"Count": class_counts.to_numpy(),
|
||||
"Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2),
|
||||
}
|
||||
)
|
||||
st.dataframe(class_df, width="stretch", hide_index=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue