Switch to predict_classes
This commit is contained in:
parent
d4a747d800
commit
b2cfddfead
3 changed files with 373 additions and 192 deletions
|
|
@ -18,13 +18,14 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
|
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
|
||||||
"""Get predicted probabilities for each cell.
|
"""Get predicted probabilities for each cell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
clf (ESPAClassifier): The trained classifier to use for predictions.
|
clf (ESPAClassifier): The trained classifier to use for predictions.
|
||||||
|
classes (list): List of class names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of predicted probabilities for each cell.
|
list: A list of predicted probabilities for each cell.
|
||||||
|
|
@ -46,11 +47,11 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
|
||||||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||||
X_batch = X_batch.to_numpy(dtype="float32")
|
X_batch = X_batch.to_numpy(dtype="float32")
|
||||||
X_batch = torch.asarray(X_batch, device=0)
|
X_batch = torch.asarray(X_batch, device=0)
|
||||||
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
|
batch_preds = clf.predict(X_batch).cpu().numpy()
|
||||||
batch_preds = gpd.GeoDataFrame(
|
batch_preds = gpd.GeoDataFrame(
|
||||||
{
|
{
|
||||||
"cell_id": cell_ids,
|
"cell_id": cell_ids,
|
||||||
"predicted_proba": batch_preds,
|
"predicted_class": [classes[i] for i in batch_preds],
|
||||||
"geometry": cell_geoms,
|
"geometry": cell_geoms,
|
||||||
},
|
},
|
||||||
crs="epsg:3413",
|
crs="epsg:3413",
|
||||||
|
|
|
||||||
|
|
@ -29,13 +29,14 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
|
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
data = get_train_dataset_file(grid=grid, level=level)
|
||||||
|
|
@ -59,7 +60,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
"initial_K": randint(20, 400),
|
"initial_K": randint(20, 400),
|
||||||
}
|
}
|
||||||
|
|
||||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42)
|
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||||
search = RandomizedSearchCV(
|
search = RandomizedSearchCV(
|
||||||
|
|
@ -69,7 +70,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
n_jobs=16,
|
n_jobs=16,
|
||||||
cv=cv,
|
cv=cv,
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=10,
|
verbose=3,
|
||||||
scoring=metrics,
|
scoring=metrics,
|
||||||
refit="f1",
|
refit="f1",
|
||||||
)
|
)
|
||||||
|
|
@ -114,6 +115,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
},
|
},
|
||||||
"cv_splits": cv.get_n_splits(),
|
"cv_splits": cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
"metrics": metrics,
|
||||||
|
"classes": ["No RTS", "RTS"],
|
||||||
}
|
}
|
||||||
settings_file = results_dir / "search_settings.toml"
|
settings_file = results_dir / "search_settings.toml"
|
||||||
print(f"Storing search settings to {settings_file}")
|
print(f"Storing search settings to {settings_file}")
|
||||||
|
|
@ -183,7 +185,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
|
|
||||||
# Predict probabilities for all cells
|
# Predict probabilities for all cells
|
||||||
print("Predicting probabilities for all cells...")
|
print("Predicting probabilities for all cells...")
|
||||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
|
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
print(f"Storing predicted probabilities to {preds_file}")
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
preds.to_parquet(preds_file)
|
preds.to_parquet(preds_file)
|
||||||
|
|
|
||||||
|
|
@ -36,11 +36,12 @@ def get_available_result_files() -> list[Path]:
|
||||||
return sorted(result_files, reverse=True) # Most recent first
|
return sorted(result_files, reverse=True) # Most recent first
|
||||||
|
|
||||||
|
|
||||||
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
|
||||||
"""Load results file and prepare binned columns.
|
"""Load results file and prepare binned columns.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the results parquet file.
|
file_path: Path to the results parquet file.
|
||||||
|
settings: Dictionary of search settings.
|
||||||
k_bin_width: Width of bins for initial_K parameter.
|
k_bin_width: Width of bins for initial_K parameter.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -50,21 +51,21 @@ def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataF
|
||||||
results = pd.read_parquet(file_path)
|
results = pd.read_parquet(file_path)
|
||||||
|
|
||||||
# Automatically determine bin width for initial_K based on data range
|
# Automatically determine bin width for initial_K based on data range
|
||||||
k_min = results["initial_K"].min()
|
k_min = settings["param_grid"]["initial_K"]["low"]
|
||||||
k_max = results["initial_K"].max()
|
k_max = settings["param_grid"]["initial_K"]["high"]
|
||||||
# Use configurable bin width, adapted to actual data range
|
# Use configurable bin width, adapted to actual data range
|
||||||
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||||
|
|
||||||
# Automatically create logarithmic bins for epsilon parameters based on data range
|
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||||
# Use 10 bins spanning the actual data range
|
# Use 10 bins spanning the actual data range
|
||||||
eps_cl_min = np.log10(results["eps_cl"].min())
|
eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"])
|
||||||
eps_cl_max = np.log10(results["eps_cl"].max())
|
eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"])
|
||||||
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1))
|
||||||
|
|
||||||
eps_e_min = np.log10(results["eps_e"].min())
|
eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"])
|
||||||
eps_e_max = np.log10(results["eps_e"].max())
|
eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"])
|
||||||
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1))
|
||||||
|
|
||||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||||
|
|
@ -184,29 +185,46 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
return common_feature_array
|
return common_feature_array
|
||||||
|
|
||||||
|
|
||||||
def _extract_probs_as_xdggs(preds: gpd.GeoDataFrame, settings: dict) -> xr.DataArray:
    """Wrap the per-cell predicted probabilities in an xdggs-decoded DataArray.

    Args:
        preds: GeoDataFrame providing the "cell_id" and "predicted_proba" columns.
        settings: Search settings; only the "grid" and "level" entries are read.

    Returns:
        One-dimensional DataArray over the "cell_ids" dimension, after ``xdggs.decode``.

    """
    grid_type = settings["grid"]
    grid_level = settings["level"]

    cell_index = preds["cell_id"].to_numpy()
    array = xr.DataArray(
        data=preds["predicted_proba"].to_numpy(),
        dims=["cell_ids"],
        coords={"cell_ids": cell_index},
    )

    # The grid called "hex" in the settings is labelled "h3" in the grid metadata;
    # every other grid name is passed through unchanged.
    grid_name = grid_type if grid_type != "hex" else "h3"
    cell_attrs = {"grid_name": grid_name, "level": grid_level}
    if grid_type == "healpix":
        cell_attrs["indexing_scheme"] = "nested"

    # The metadata is attached to the coordinate before decoding.
    array.cell_ids.attrs = cell_attrs
    return xdggs.decode(array)
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_prediction_map(preds: gpd.GeoDataFrame):
|
def _plot_prediction_map(preds: gpd.GeoDataFrame):
|
||||||
return preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||||
# return probs.dggs.explore()
|
|
||||||
|
|
||||||
|
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
    """Build a bar chart with one bar per predicted class.

    Args:
        preds: GeoDataFrame whose "predicted_class" column holds the class of each cell.

    Returns:
        Altair bar chart (400x300) of cells per class; the tooltip also shows each class's
        share of all cells as a percentage rounded to two decimals.

    """
    # Count on a plain copy of the column values rather than on the GeoDataFrame column itself.
    labels = pd.Series(preds["predicted_class"].to_numpy(), name="predicted_class")
    counts = labels.value_counts().reset_index()
    counts.columns = ["class", "count"]

    total = counts["count"].sum()
    counts["percentage"] = (counts["count"] / total * 100).round(2)

    tooltip = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("count:Q", title="Count"),
        alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
    ]
    bars = alt.Chart(counts).mark_bar()
    encoded = bars.encode(
        x=alt.X("class:N", title="Predicted Class"),
        y=alt.Y("count:Q", title="Number of Cells"),
        # The x-axis already names each class, so the color legend is suppressed.
        color=alt.Color("class:N", title="Class", legend=None),
        tooltip=tooltip,
    )
    return encoded.properties(
        width=400,
        height=300,
        title="Predicted Class Distribution",
    )
|
||||||
|
|
||||||
|
|
||||||
def _plot_k_binned(
|
def _plot_k_binned(
|
||||||
|
|
@ -740,6 +758,55 @@ def _plot_common_features(common_array: xr.DataArray):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
    """Scatter one mean test metric against another, colored by a search parameter.

    Args:
        results: Search results. The chart references the ``mean_test_{x_metric}``,
            ``mean_test_{y_metric}`` and ``mean_test_f1`` columns as well as ``initial_K``,
            ``eps_cl`` and ``eps_e`` (the last four appear in the tooltip).
        x_metric: Metric shown on the x-axis (e.g. 'precision', 'accuracy').
        y_metric: Metric shown on the y-axis (e.g. 'recall', 'jaccard').
        color_param: Column mapped to the color channel. 'eps_cl' and 'eps_e' get a
            logarithmic viridis scale; any other value gets viridis with the default scale type.

    Returns:
        Interactive Altair circle chart (400x400).

    """
    x_field = f"mean_test_{x_metric}:Q"
    y_field = f"mean_test_{y_metric}:Q"
    x_label = x_metric.capitalize()
    y_label = y_metric.capitalize()

    # Only the epsilon parameters are drawn on a logarithmic color scale.
    scale_options = {"scheme": "viridis"}
    if color_param in ("eps_cl", "eps_e"):
        scale_options["type"] = "log"
    color = alt.Color(f"{color_param}:Q", scale=alt.Scale(**scale_options), title=color_param)

    tooltip = [
        alt.Tooltip(x_field, format=".4f", title=x_label),
        alt.Tooltip(y_field, format=".4f", title=y_label),
        alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
        alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
        alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
        alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
    ]

    points = alt.Chart(results).mark_circle(size=60, opacity=0.7)
    encoded = points.encode(
        # zero=False lets both axes leave out zero instead of forcing it into view.
        x=alt.X(x_field, title=x_label, scale=alt.Scale(zero=False)),
        y=alt.Y(y_field, title=y_label, scale=alt.Scale(zero=False)),
        color=color,
        tooltip=tooltip,
    )
    return encoded.properties(width=400, height=400).interactive()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run Streamlit dashboard application."""
|
"""Run Streamlit dashboard application."""
|
||||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||||
|
|
@ -769,7 +836,8 @@ def main():
|
||||||
|
|
||||||
# Load and prepare data with default bin width (will be reloaded with custom width later)
|
# Load and prepare data with default bin width (will be reloaded with custom width later)
|
||||||
with st.spinner("Loading results..."):
|
with st.spinner("Loading results..."):
|
||||||
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40)
|
settings = toml.load(results_dir / "search_settings.toml")["settings"]
|
||||||
|
results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40)
|
||||||
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
|
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
|
||||||
n_features = model_state.sizes["feature"]
|
n_features = model_state.sizes["feature"]
|
||||||
model_state["feature_weights"] *= n_features
|
model_state["feature_weights"] *= n_features
|
||||||
|
|
@ -777,29 +845,55 @@ def main():
|
||||||
era5_feature_array = extract_era5_features(model_state)
|
era5_feature_array = extract_era5_features(model_state)
|
||||||
common_feature_array = extract_common_features(model_state)
|
common_feature_array = extract_common_features(model_state)
|
||||||
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||||
settings = toml.load(results_dir / "search_settings.toml")["settings"]
|
|
||||||
probs = _extract_probs_as_xdggs(predictions, settings)
|
|
||||||
|
|
||||||
st.sidebar.success(f"Loaded {len(results)} results")
|
st.sidebar.success(f"Loaded {len(results)} results")
|
||||||
|
|
||||||
# Metric selection
|
|
||||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
|
||||||
selected_metric = st.sidebar.selectbox(
|
|
||||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Display some basic statistics first (lightweight)
|
# Display some basic statistics first (lightweight)
|
||||||
st.header("Dataset Overview")
|
st.header("Parameter-Search Overview")
|
||||||
col1, col2, col3 = st.columns(3)
|
|
||||||
|
# Show total runs and best model info
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
with col1:
|
with col1:
|
||||||
st.metric("Total Runs", len(results))
|
st.metric("Total Runs", len(results))
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
best_score = results[f"mean_test_{selected_metric}"].max()
|
# Best model based on F1 score
|
||||||
st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}")
|
best_f1_idx = results["mean_test_f1"].idxmax()
|
||||||
|
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
|
||||||
|
|
||||||
|
# Show best parameters for the best model
|
||||||
|
with st.expander("Best Model Parameters (by F1)", expanded=True):
|
||||||
|
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
|
||||||
|
col1, col2, col3, col4, col5 = st.columns(5)
|
||||||
|
with col1:
|
||||||
|
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
|
||||||
|
with col2:
|
||||||
|
st.metric("eps_cl", f"{best_params['eps_cl']:.2e}")
|
||||||
with col3:
|
with col3:
|
||||||
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
st.metric("eps_e", f"{best_params['eps_e']:.2e}")
|
||||||
best_k = results.loc[best_idx, "initial_K"]
|
with col4:
|
||||||
st.metric("Best K", f"{best_k:.0f}")
|
st.metric("F1 Score", f"{best_params['mean_test_f1']:.4f}")
|
||||||
|
with col5:
|
||||||
|
st.metric("F1 Std", f"{best_params['std_test_f1']:.4f}")
|
||||||
|
|
||||||
|
# Show all metrics
|
||||||
|
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
||||||
|
cols = st.columns(len(available_metrics))
|
||||||
|
|
||||||
|
for idx, metric in enumerate(available_metrics):
|
||||||
|
with cols[idx]:
|
||||||
|
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
|
||||||
|
st.metric(f"{metric.capitalize()}", f"{best_score:.4f}")
|
||||||
|
|
||||||
|
# Create tabs for different visualizations
|
||||||
|
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
|
||||||
|
|
||||||
|
with tab1:
|
||||||
|
# Metric selection - only used in this tab
|
||||||
|
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
||||||
|
selected_metric = st.selectbox(
|
||||||
|
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||||
|
)
|
||||||
|
|
||||||
# Show best parameters
|
# Show best parameters
|
||||||
with st.expander("Best Parameters"):
|
with st.expander("Best Parameters"):
|
||||||
|
|
@ -807,10 +901,6 @@ def main():
|
||||||
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
||||||
st.dataframe(best_params.to_frame().T, width="content")
|
st.dataframe(best_params.to_frame().T, width="content")
|
||||||
|
|
||||||
# Create tabs for different visualizations
|
|
||||||
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Predictions Map"])
|
|
||||||
|
|
||||||
with tab1:
|
|
||||||
# Main plots
|
# Main plots
|
||||||
st.header(f"Visualization for {selected_metric.capitalize()}")
|
st.header(f"Visualization for {selected_metric.capitalize()}")
|
||||||
|
|
||||||
|
|
@ -847,7 +937,9 @@ def main():
|
||||||
# Reload data if bin width changed from default
|
# Reload data if bin width changed from default
|
||||||
if k_bin_width != 40:
|
if k_bin_width != 40:
|
||||||
with st.spinner("Re-binning data..."):
|
with st.spinner("Re-binning data..."):
|
||||||
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width)
|
results = load_and_prepare_results(
|
||||||
|
results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width
|
||||||
|
)
|
||||||
|
|
||||||
# K-binned plots
|
# K-binned plots
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
|
|
@ -885,6 +977,30 @@ def main():
|
||||||
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
|
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
|
||||||
st.altair_chart(chart4, use_container_width=True)
|
st.altair_chart(chart4, use_container_width=True)
|
||||||
|
|
||||||
|
# Metric comparison plots
|
||||||
|
st.header("Metric Comparisons")
|
||||||
|
|
||||||
|
# Color parameter selection
|
||||||
|
color_param = st.selectbox(
|
||||||
|
"Select Color Parameter",
|
||||||
|
options=["initial_K", "eps_cl", "eps_e"],
|
||||||
|
help="Choose which parameter to use for coloring the scatter plots",
|
||||||
|
)
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.subheader("Recall vs Precision")
|
||||||
|
with st.spinner("Generating Recall vs Precision plot..."):
|
||||||
|
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
|
||||||
|
st.altair_chart(recall_precision_chart, use_container_width=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.subheader("Accuracy vs Jaccard")
|
||||||
|
with st.spinner("Generating Accuracy vs Jaccard plot..."):
|
||||||
|
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
|
||||||
|
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
|
||||||
|
|
||||||
# Optional: Raw data table
|
# Optional: Raw data table
|
||||||
with st.expander("View Raw Results Data"):
|
with st.expander("View Raw Results Data"):
|
||||||
st.dataframe(results, width="stretch")
|
st.dataframe(results, width="stretch")
|
||||||
|
|
@ -993,7 +1109,8 @@ def main():
|
||||||
|
|
||||||
# Embedding features analysis (if present)
|
# Embedding features analysis (if present)
|
||||||
if embedding_feature_array is not None:
|
if embedding_feature_array is not None:
|
||||||
st.subheader(":artificial_satellite: Embedding Feature Analysis")
|
with st.container(border=True):
|
||||||
|
st.header(":artificial_satellite: Embedding Feature Analysis")
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
Analysis of embedding features showing which aggregations, bands, and years
|
Analysis of embedding features showing which aggregations, bands, and years
|
||||||
|
|
@ -1044,7 +1161,8 @@ def main():
|
||||||
|
|
||||||
# ERA5 features analysis (if present)
|
# ERA5 features analysis (if present)
|
||||||
if era5_feature_array is not None:
|
if era5_feature_array is not None:
|
||||||
st.subheader(":partly_sunny: ERA5 Feature Analysis")
|
with st.container(border=True):
|
||||||
|
st.header(":partly_sunny: ERA5 Feature Analysis")
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
Analysis of ERA5 climate features showing which variables and time periods
|
Analysis of ERA5 climate features showing which variables and time periods
|
||||||
|
|
@ -1093,7 +1211,8 @@ def main():
|
||||||
|
|
||||||
# Common features analysis (if present)
|
# Common features analysis (if present)
|
||||||
if common_feature_array is not None:
|
if common_feature_array is not None:
|
||||||
st.subheader(":world_map: Common Feature Analysis")
|
with st.container(border=True):
|
||||||
|
st.header(":world_map: Common Feature Analysis")
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
Analysis of common features including cell area, water area, land area, land ratio,
|
Analysis of common features including cell area, water area, land area, land ratio,
|
||||||
|
|
@ -1133,7 +1252,8 @@ def main():
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
**Interpretation:**
|
**Interpretation:**
|
||||||
- **cell_area, water_area, land_area**: Spatial extent features that may indicate size-related patterns
|
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
|
||||||
|
size-related patterns
|
||||||
- **land_ratio**: Proportion of land vs water in each cell
|
- **land_ratio**: Proportion of land vs water in each cell
|
||||||
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||||
- Positive weights indicate features that increase the probability of the positive class
|
- Positive weights indicate features that increase the probability of the positive class
|
||||||
|
|
@ -1144,13 +1264,71 @@ def main():
|
||||||
st.info("No common features found in this model.")
|
st.info("No common features found in this model.")
|
||||||
|
|
||||||
with tab3:
|
with tab3:
|
||||||
# Map visualization
|
# Inference analysis
|
||||||
st.header("Predictions Map")
|
st.header("Inference Analysis")
|
||||||
st.markdown("Map showing predicted classes from the best estimator")
|
st.markdown("Comprehensive analysis of model predictions on the evaluation dataset")
|
||||||
with st.spinner("Generating map..."):
|
|
||||||
|
# Summary statistics
|
||||||
|
st.subheader("Prediction Statistics")
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
total_cells = len(predictions)
|
||||||
|
st.metric("Total Cells", f"{total_cells:,}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
n_classes = predictions["predicted_class"].nunique()
|
||||||
|
st.metric("Number of Classes", n_classes)
|
||||||
|
|
||||||
|
# Class distribution visualization
|
||||||
|
st.subheader("Class Distribution")
|
||||||
|
|
||||||
|
with st.spinner("Generating class distribution..."):
|
||||||
|
class_dist_chart = _plot_prediction_class_distribution(predictions)
|
||||||
|
st.altair_chart(class_dist_chart, use_container_width=True)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Interpretation:**
|
||||||
|
- Shows the balance between predicted classes
|
||||||
|
- Class imbalance may indicate regional patterns or model bias
|
||||||
|
- Each bar represents the count of cells predicted for that class
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Interactive map
|
||||||
|
st.subheader("Interactive Prediction Map")
|
||||||
|
st.markdown("Explore predictions spatially with the interactive map below")
|
||||||
|
|
||||||
|
with st.spinner("Generating interactive map..."):
|
||||||
chart_map = _plot_prediction_map(predictions)
|
chart_map = _plot_prediction_map(predictions)
|
||||||
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
||||||
|
|
||||||
|
# Additional statistics in expander
|
||||||
|
with st.expander("Detailed Prediction Statistics"):
|
||||||
|
st.write("**Class Distribution:**")
|
||||||
|
class_counts = predictions["predicted_class"].value_counts().sort_index()
|
||||||
|
|
||||||
|
# Create columns for better layout
|
||||||
|
n_cols = min(5, len(class_counts))
|
||||||
|
cols = st.columns(n_cols)
|
||||||
|
|
||||||
|
for idx, (class_label, count) in enumerate(class_counts.items()):
|
||||||
|
percentage = count / len(predictions) * 100
|
||||||
|
with cols[idx % n_cols]:
|
||||||
|
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
|
||||||
|
|
||||||
|
# Show detailed table
|
||||||
|
st.write("**Detailed Class Breakdown:**")
|
||||||
|
class_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Class": class_counts.index,
|
||||||
|
"Count": class_counts.to_numpy(),
|
||||||
|
"Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
st.dataframe(class_df, width="stretch", hide_index=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue