Enhance training analysis page with test metrics and confusion matrix
- Added a section to display test metrics for model performance on the held-out test set.
- Implemented confusion matrix visualization to analyze the prediction breakdown.
- Refactored sidebar settings to streamline metric selection and improve user experience.
- Updated cross-validation statistics to compare CV performance with test metrics.
- Enhanced DatasetEnsemble methods to handle empty-data scenarios gracefully.
- Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation.
- Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
parent 4fecac535c
commit c92e856c55

23 changed files with 1845 additions and 484 deletions
223  pixi.lock (generated)

[Generated lockfile diff collapsed. Dependency updates: array-api-compat 1.12.0 → 1.13.0, array-api-extra 0.9.1 → 0.9.2, astropy-iers-data 0.2025.12.22 → 0.2026.1.5, azure-storage-blob 12.27.1 → 12.28.0, cyclopts 4.4.1 → 4.4.4, gitpython 3.1.45 → 3.1.46, google-auth 2.45.0 → 2.47.0, icechunk 1.1.14 → 1.1.15, ipython 9.8.0 → 9.9.0, netcdf4 1.7.3 → 1.7.4, nvidia-nccl-cu12 2.28.9 → 2.29.2, psygnal 0.15.0 → 0.15.1, rasterio 1.4.4 → 1.5.0, textual 6.11.0 → 7.0.1, ultraplot 1.66.0 → 1.70.0, vl-convert-python 1.8.0 → 1.9.0. Newly added (test tooling): pytest 9.0.2, pluggy 1.6.0, iniconfig 2.3.0. Removed: click-plugins 1.1.1.2. The local entropice package hash changed and its requirements gain pytest>=9.0.2,<10.]
pyproject.toml

@@ -66,7 +66,7 @@ dependencies = [
     "pypalettes>=0.2.1,<0.3",
     "ty>=0.0.2,<0.0.3",
     "ruff>=0.14.9,<0.15",
-    "pandas-stubs>=2.3.3.251201,<3",
+    "pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10",
 ]

 [project.scripts]
195  scripts/recalculate_test_metrics.py (new file)

@@ -0,0 +1,195 @@
#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.

This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""

import pickle
from pathlib import Path

import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import confusion_matrix

from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR

# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
set_config(array_api_dispatch=True)


def recalculate_metrics(results_dir: Path):
    """Recalculate test metrics and confusion matrix for a training result.

    Args:
        results_dir: Path to the results directory containing the trained model.

    """
    print(f"\nProcessing: {results_dir}")

    # Load the search settings to get training configuration
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print("  ❌ Missing search_settings.toml, skipping...")
        return

    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    # Check if metrics already exist
    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"

    # if test_metrics_file.exists() and cm_file.exists():
    #     print("  ✓ Metrics already exist, skipping...")
    #     return

    # Load the best estimator
    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print("  ❌ Missing best_estimator_model.pkl, skipping...")
        return

    print(f"  Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Recreate the dataset ensemble
    print("  Recreating training dataset...")
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get(
            "members",
            [
                "AlphaEarth",
                "ArcticDEM",
                "ERA5-yearly",
                "ERA5-seasonal",
                "ERA5-shoulder",
            ],
        ),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    device = "torch" if model in ["espa"] else "cuda"

    # Create training data
    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)

    # Prepare test data - match training.py's approach
    print("  Preparing test data...")
    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
    y_test = (
        training_data.y.test.get()
        if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
        else training_data.y.test
    )

    # Compute predictions on the test set (use original device data)
    print("  Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)

    # Use torch
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    # Compute metrics manually to avoid device issues
    print("  Computing test metrics...")
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        jaccard_score,
        precision_score,
        recall_score,
    )

    test_metrics = {}
    if task == "binary":
        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
        test_metrics["recall"] = float(recall_score(y_test, y_pred))
        test_metrics["precision"] = float(precision_score(y_test, y_pred))
        test_metrics["f1"] = float(f1_score(y_test, y_pred))
        test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
    else:
        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
        test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
        test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
        test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
        test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
        test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
        test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))

    # Get confusion matrix
    print("  Computing confusion matrix...")
    labels = list(range(len(training_data.y.labels)))
    labels = torch.as_tensor(np.array(labels), device="cuda")
    print("  Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
    print("  Device of y_pred:", getattr(y_pred, "device", "cpu"))
    print("  Device of labels:", getattr(labels, "device", "cpu"))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm = cm.cpu().numpy()
    labels = labels.cpu().numpy()
    label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
    cm_xr = xr.DataArray(
        cm,
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )

    # Store the test metrics
    if not test_metrics_file.exists():
        print(f"  Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)
    else:
        print("  ✓ Test metrics already exist")

    # Store the confusion matrix
    if True:
        # if not cm_file.exists():
        print(f"  Storing confusion matrix to {cm_file.name}...")
        cm_xr.to_netcdf(cm_file, engine="h5netcdf")
    else:
        print("  ✓ Confusion matrix already exists")

    print("  ✓ Done!")


def main():
    """Find all training results and recalculate metrics for those missing them."""
    print("Searching for training results directories...")

    # Find all results directories
    results_dirs = sorted([d for d in RESULTS_DIR.glob("*") if d.is_dir()])

    print(f"Found {len(results_dirs)} results directories.\n")

    for results_dir in results_dirs:
        recalculate_metrics(results_dir)
        # try:
        # except Exception as e:
        #     print(f"  ❌ Error processing {results_dir.name}: {e}")
        #     continue

    print("\n✅ All done!")


if __name__ == "__main__":
    main()
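For reference, a run directory processed by this script ends up with a test_metrics.toml and a confusion_matrix.nc. A minimal read-back sketch (the run directory name is hypothetical; it assumes the same toml/xarray/h5netcdf stack used above):

from pathlib import Path

import toml
import xarray as xr

result_dir = Path("results/some_run_cv5")  # hypothetical run directory
metrics = toml.load(result_dir / "test_metrics.toml")["test_metrics"]
cm = xr.open_dataarray(result_dir / "confusion_matrix.nc", engine="h5netcdf")

print(f"accuracy = {metrics['accuracy']:.3f}")
# Row-normalising the matrix puts per-class recall on the diagonal
recall_per_class = (cm / cm.sum("predicted_label")).values.diagonal()
print(dict(zip(cm.true_label.values, recall_per_class)))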
58  scripts/rechunk_zarr.py (new file)

@@ -0,0 +1,58 @@
import xarray as xr
import zarr
from rich import print
import dask.distributed as dd

from entropice.utils.paths import get_era5_stores
import entropice.utils.codecs


def print_info(daily_raw=None, show_vars: bool = True):
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
    print("=== Daily INFO ===")
    print(f"  Dims: {daily_raw.sizes}")
    numchunks = 1
    chunksizes = {}
    approxchunksize = 4  # 4 bytes per element (float32)
    for d, cs in daily_raw.chunksizes.items():
        numchunks *= len(cs)
        chunksizes[d] = max(cs)
        approxchunksize *= max(cs)
    approxchunksize /= 1e6  # bytes -> MB
    print(f"  Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f"  Encoding: {daily_raw.encoding}")
    if show_vars:
        print("  Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f"  {var} Encoding:")
            print(da.encoding)
            print("")


def rechunk():
    daily_store = get_era5_stores("daily")
    daily_raw = xr.open_zarr(daily_store, consolidated=False)
    print_info(daily_raw, False)
    daily_raw = daily_raw.chunk({
        "time": 120,
        "latitude": -1,  # should be 337
        "longitude": -1,  # should be 3600
    })
    print_info(daily_raw, False)

    encoding = entropice.utils.codecs.from_ds(daily_raw)
    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
    daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)


if __name__ == "__main__":
    with (
        dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(client)
        print(client.dashboard_link)
        rechunk()
    print("Done.")
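Once the rewrite finishes, the new store can be checked against the target layout before replacing the original; a small sketch under the same store-path convention used above:

import xarray as xr

from entropice.utils.paths import get_era5_stores

daily_store = get_era5_stores("daily")
rechunked_store = daily_store.with_stem(f"{daily_store.stem}_rechunked")

ds = xr.open_zarr(rechunked_store, consolidated=False)
# time should now be chunked in blocks of 120; lat/lon each in a single chunk
assert max(ds.chunksizes["time"]) == 120
assert len(ds.chunksizes["latitude"]) == 1
assert len(ds.chunksizes["longitude"]) == 1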
144  scripts/rerun_missing_inference.py (new file)

@@ -0,0 +1,144 @@
#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.

This script searches through training result directories and identifies those that have
a trained model but are missing inference results. It then loads the model and dataset
configuration, reruns inference, and saves the results.
"""

import pickle
from pathlib import Path

import toml
from rich.console import Console
from rich.progress import track

from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import RESULTS_DIR

console = Console()


def find_incomplete_trainings() -> list[Path]:
    """Find training result directories missing inference results.

    Returns:
        list[Path]: List of directories with trained models but missing predictions.

    """
    incomplete = []

    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
        return incomplete

    # Search for all training result directories
    for result_dir in RESULTS_DIR.glob("*_cv*"):
        if not result_dir.is_dir():
            continue

        model_file = result_dir / "best_estimator_model.pkl"
        settings_file = result_dir / "search_settings.toml"
        predictions_file = result_dir / "predicted_probabilities.parquet"

        # Check if model and settings exist but predictions are missing
        if model_file.exists() and settings_file.exists() and not predictions_file.exists():
            incomplete.append(result_dir)

    return incomplete


def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise.

    """
    try:
        console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")

        # Load settings
        settings_file = result_dir / "search_settings.toml"
        with open(settings_file) as f:
            settings_data = toml.load(f)

        settings = settings_data["settings"]

        # Reconstruct DatasetEnsemble from settings
        ensemble = DatasetEnsemble(
            grid=settings["grid"],
            level=settings["level"],
            target=settings["target"],
            members=settings["members"],
            dimension_filters=settings.get("dimension_filters", {}),
            variable_filters=settings.get("variable_filters", {}),
            filter_target=settings.get("filter_target", False),
            add_lonlat=settings.get("add_lonlat", True),
        )

        # Load trained model
        model_file = result_dir / "best_estimator_model.pkl"
        with open(model_file, "rb") as f:
            clf = pickle.load(f)

        console.print("[green]✓[/green] Loaded model and settings")

        # Get class labels
        classes = settings["classes"]

        # Run inference
        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=classes)

        # Save predictions
        preds_file = result_dir / "predicted_probabilities.parquet"
        preds.to_parquet(preds_file)

        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
        return True

    except Exception as e:
        console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]")
        import traceback

        console.print(f"[red]{traceback.format_exc()}[/red]")
        return False


def main():
    """Rerun missing inferences for incomplete training results."""
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")

    incomplete_dirs = find_incomplete_trainings()

    if not incomplete_dirs:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return

    console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
    for d in incomplete_dirs:
        console.print(f"  • {d.name}")

    console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")

    successful = 0
    failed = 0

    for result_dir in track(incomplete_dirs, description="Rerunning inference"):
        if rerun_inference(result_dir):
            successful += 1
        else:
            failed += 1

    console.print("\n[bold]Summary:[/bold]")
    console.print(f"  [green]Successful: {successful}[/green]")
    console.print(f"  [red]Failed: {failed}[/red]")


if __name__ == "__main__":
    main()
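Note: to preview which result directories this script would touch without loading any models, a minimal sketch along these lines can help. The `dry_run` helper is hypothetical, not part of the script; it only assumes the same `RESULTS_DIR` layout the script uses.

```python
# Hypothetical preview helper: mirrors find_incomplete_trainings() but only prints.
from entropice.utils.paths import RESULTS_DIR  # same import as the script above


def dry_run() -> None:
    for d in sorted(RESULTS_DIR.glob("*_cv*")):
        if not d.is_dir():
            continue
        has_model = (d / "best_estimator_model.pkl").exists()
        has_preds = (d / "predicted_probabilities.parquet").exists()
        if has_model and not has_preds:
            print(f"would rerun: {d.name}")
```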
@@ -10,9 +10,11 @@ import pandas as pd
 import pydeck as pdk
 import streamlit as st
 
+from entropice.dashboard.utils.class_ordering import get_ordered_classes
 from entropice.dashboard.utils.colors import get_cmap, get_palette
 from entropice.dashboard.utils.geometry import fix_hex_geometry
 from entropice.ml.dataset import DatasetEnsemble
+from entropice.ml.training import TrainingSettings
 
 
 def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@@ -125,7 +127,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
     )
 
 
-def render_parameter_distributions(results: pd.DataFrame, settings: dict | None = None):
+def render_parameter_distributions(results: pd.DataFrame, settings: TrainingSettings | None = None):
     """Render histograms of parameter distributions explored.
 
     Args:
@@ -1152,15 +1154,18 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
 
 
 @st.fragment
-def render_confusion_matrix_map(result_path: Path, settings: dict):
-    """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).
+def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
+    """Render 3D pydeck map showing prediction results.
+
+    Uses true labels for elevation (height) and different shades of red for incorrect predictions
+    based on the predicted class.
 
     Args:
         result_path: Path to the training result directory.
-        settings: Settings dictionary containing grid, level, task, and target information.
+        settings: Training settings containing grid, level, task, and target information.
 
     """
-    st.subheader("🗺️ Confusion Matrix Spatial Distribution")
+    st.subheader("🗺️ Prediction Results Map")
 
     # Load predicted probabilities
     preds_file = result_path / "predicted_probabilities.parquet"
@@ -1190,62 +1195,41 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
         st.error(f"Error loading training data: {e}")
         return
 
-    # Get the labeled cells (those with true labels)
-    labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]
+    # Get all cells from the complete dataset (not just test split)
+    # Use the full dataset which includes both train and test splits
+    all_cells = training_data.dataset.copy()
 
     # Merge predictions with true labels
     # Reset index to avoid ambiguity between index and column
-    labeled_gdf = labeled_cells.copy()
-    labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"})
-    labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()
+    labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"})
+    labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy()
 
-    # Merge with predictions - ensure we keep GeoDataFrame type
-    merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
+    # Merge with predictions - use left join to keep all cells
+    merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left")
     merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
 
+    # Mark which cells have predictions (test split) vs not (training split)
+    merged["in_test_split"] = merged["predicted_class"].notna()
+
+    # For cells without predictions (training split), use true class as predicted class for visualization
+    merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"])
+
     if len(merged) == 0:
         st.warning("No matching predictions found for labeled cells.")
         return
 
-    # Determine confusion matrix category
-    def get_confusion_category(row):
-        true_label = row["true_class"]
-        pred_label = row["predicted_class"]
+    # Mark correct vs incorrect predictions (only meaningful for test split)
+    merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
 
-        if task == "binary":
-            # For binary classification
-            if true_label == "RTS" and pred_label == "RTS":
-                return "True Positive"
-            elif true_label == "RTS" and pred_label == "No-RTS":
-                return "False Negative"
-            elif true_label == "No-RTS" and pred_label == "RTS":
-                return "False Positive"
-            else:  # true_label == "No-RTS" and pred_label == "No-RTS"
-                return "True Negative"
-        else:
-            # For multiclass (count/density)
-            if true_label == pred_label:
-                return "Correct"
-            else:
-                return "Incorrect"
-
-    merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)
+    # Get ordered class labels for the task
+    ordered_classes = get_ordered_classes(task)
 
     # Create controls
-    col1, col2 = st.columns([3, 1])
+    col1, col2, col3 = st.columns([2, 1, 1])
 
     with col1:
-        # Filter by confusion category
-        if task == "binary":
-            categories = [
-                "All",
-                "True Positive",
-                "False Positive",
-                "True Negative",
-                "False Negative",
-            ]
-        else:
-            categories = ["All", "Correct", "Incorrect"]
+        # Filter by prediction correctness and split
+        categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"]
 
         selected_category = st.selectbox(
            "Filter by Category",
@@ -1263,10 +1247,26 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
             key="confusion_map_opacity",
         )
 
+    with col3:
+        line_width = st.slider(
+            "Line Width",
+            min_value=0.5,
+            max_value=3.0,
+            value=1.0,
+            step=0.5,
+            key="confusion_map_line_width",
+        )
+
     # Filter data if needed
-    if selected_category != "All":
-        display_gdf = merged[merged["confusion_category"] == selected_category].copy()
-    else:
+    if selected_category == "Test Split Only":
+        display_gdf = merged[merged["in_test_split"]].copy()
+    elif selected_category == "Training Split Only":
+        display_gdf = merged[~merged["in_test_split"]].copy()
+    elif selected_category == "Correct (Test)":
+        display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy()
+    elif selected_category == "Incorrect (Test)":
+        display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy()
+    else:  # "All"
         display_gdf = merged.copy()
 
     if len(display_gdf) == 0:
@@ -1280,49 +1280,72 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
     if grid == "hex":
         display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
 
-    # Assign colors based on confusion category
-    if task == "binary":
-        color_map = {
-            "True Positive": [46, 204, 113],  # Green
-            "False Positive": [231, 76, 60],  # Red
-            "True Negative": [52, 152, 219],  # Blue
-            "False Negative": [241, 196, 15],  # Yellow
-        }
-    else:
-        color_map = {
-            "Correct": [46, 204, 113],  # Green
-            "Incorrect": [231, 76, 60],  # Red
-        }
+    # Get red material colormap for incorrect predictions
+    red_cmap = get_cmap("red_predictions")  # Use red_material palette
+    n_classes = len(ordered_classes)
 
-    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)
+    # Assign colors based on correctness
+    def get_color(row):
+        if row["is_correct"]:
+            # Green for correct predictions
+            return [46, 204, 113]
+        else:
+            # Different shades of red for each predicted class (ordered)
+            pred_class = row["predicted_class"]
+            if pred_class in ordered_classes:
+                class_idx = ordered_classes.index(pred_class)
+                # Sample from red colormap based on class index
+                color_value = red_cmap(class_idx / max(n_classes - 1, 1))
+                return [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
+            else:
+                # Fallback red if class not found
+                return [231, 76, 60]
 
-    # Add elevation based on confusion category (higher for errors)
-    if task == "binary":
-        elevation_map = {
-            "True Positive": 0.8,
-            "False Positive": 1.0,
-            "True Negative": 0.3,
-            "False Negative": 1.0,
-        }
-    else:
-        elevation_map = {
-            "Correct": 0.5,
-            "Incorrect": 1.0,
-        }
+    display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
 
-    display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)
+    # Add line color based on split: blue for test split, orange for training split
+    def get_line_color(row):
+        if row["in_test_split"]:
+            return [52, 152, 219]  # Blue for test split
+        else:
+            return [230, 126, 34]  # Orange for training split
+
+    display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
+
+    # Add elevation based on TRUE label (not predicted)
+    # Map each true class to a height based on its position in the ordered list
+    def get_elevation(row):
+        true_class = row["true_class"]
+        if true_class in ordered_classes:
+            class_idx = ordered_classes.index(true_class)
+            # Normalize to 0-1 range based on class position
+            return (class_idx + 1) / n_classes
+        else:
+            return 0.5  # Default elevation
+
+    display_gdf_wgs84["elevation"] = display_gdf_wgs84.apply(get_elevation, axis=1)
 
     # Convert to GeoJSON format
     geojson_data = []
     for _, row in display_gdf_wgs84.iterrows():
+        # Determine split and status for tooltip
+        split_name = "Test Split" if row["in_test_split"] else "Training Split"
+        if row["in_test_split"]:
+            status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
+        else:
+            status = "(No prediction - training data)"
+
         feature = {
             "type": "Feature",
             "geometry": row["geometry"].__geo_interface__,
             "properties": {
                 "true_class": str(row["true_class"]),
-                "predicted_class": str(row["predicted_class"]),
-                "confusion_category": str(row["confusion_category"]),
+                "predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A",
+                "is_correct": bool(row["is_correct"]),
+                "split": split_name,
+                "status": status,
                 "fill_color": row["fill_color"],
+                "line_color": row["line_color"],
                 "elevation": float(row["elevation"]),
             },
         }
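Aside on the `get_color` logic above: sampling a colormap at evenly spaced points and scaling to 0-255 RGB looks roughly like the sketch below. The stock matplotlib `Reds` colormap stands in for the project's `red_predictions` palette, whose exact API is assumed to be matplotlib-compatible.

```python
# Sketch of per-class red shades; "Reds" stands in for get_cmap("red_predictions").
from matplotlib import colormaps

cmap = colormaps["Reds"]
n_classes = 4
for class_idx in range(n_classes):
    r, g, b, _ = cmap(class_idx / max(n_classes - 1, 1))  # RGBA floats in [0, 1]
    print([int(r * 255), int(g * 255), int(b * 255)])
```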
@@ -1338,8 +1361,8 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
         extruded=True,
         wireframe=False,
         get_fill_color="properties.fill_color",
-        get_line_color=[80, 80, 80],
-        line_width_min_pixels=0.5,
+        get_line_color="properties.line_color",
+        line_width_min_pixels=line_width,
         get_elevation="properties.elevation",
         elevation_scale=500000,
         pickable=True,
@@ -1353,9 +1376,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
         layers=[layer],
         initial_view_state=view_state,
         tooltip={
-            "html": "<b>True Label:</b> {true_class}<br/>"
+            "html": "<b>Split:</b> {split}<br/>"
+            "<b>True Label:</b> {true_class}<br/>"
             "<b>Predicted Label:</b> {predicted_class}<br/>"
-            "<b>Category:</b> {confusion_category}",
+            "<b>Status:</b> {status}",
             "style": {"backgroundColor": "steelblue", "color": "white"},
         },
         map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
@@ -1365,56 +1389,240 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
     st.pydeck_chart(deck)
 
     # Show statistics
-    col1, col2, col3 = st.columns(3)
+    col1, col2, col3, col4 = st.columns(4)
 
     with col1:
         st.metric("Total Labeled Cells", len(merged))
 
-    if task == "binary":
-        with col2:
-            tp = len(merged[merged["confusion_category"] == "True Positive"])
-            fp = len(merged[merged["confusion_category"] == "False Positive"])
-            tn = len(merged[merged["confusion_category"] == "True Negative"])
-            fn = len(merged[merged["confusion_category"] == "False Negative"])
+    with col2:
+        test_count = len(merged[merged["in_test_split"]])
+        st.metric("Test Split", test_count)
 
-            accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
-            st.metric("Accuracy", f"{accuracy:.2%}")
+    with col3:
+        train_count = len(merged[~merged["in_test_split"]])
+        st.metric("Training Split", train_count)
 
-        with col3:
-            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
-            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-            st.metric("F1 Score", f"{f1:.3f}")
-
-        # Show confusion matrix counts
-        st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
-    else:
-        with col2:
-            correct = len(merged[merged["confusion_category"] == "Correct"])
-            accuracy = correct / len(merged) if len(merged) > 0 else 0
-            st.metric("Accuracy", f"{accuracy:.2%}")
-
-        with col3:
-            incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
-            st.metric("Incorrect", incorrect)
+    with col4:
+        test_cells = merged[merged["in_test_split"]]
+        if len(test_cells) > 0:
+            correct = len(test_cells[test_cells["is_correct"]])
+            accuracy = correct / len(test_cells)
+            st.metric("Test Accuracy", f"{accuracy:.2%}")
+        else:
+            st.metric("Test Accuracy", "N/A")
 
     # Add legend
     with st.expander("Legend", expanded=True):
-        st.markdown("**Confusion Matrix Categories:**")
+        # Split indicators (border colors)
+        st.markdown("**Data Split (Border Color):**")
 
-        for category, color in color_map.items():
-            count = len(merged[merged["confusion_category"] == category])
-            percentage = count / len(merged) * 100 if len(merged) > 0 else 0
+        test_count = len(merged[merged["in_test_split"]])
+        train_count = len(merged[~merged["in_test_split"]])
 
-            st.markdown(
-                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
-                f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); '
-                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
-                f"<span>{category}: {count} ({percentage:.1f}%)</span></div>",
-                unsafe_allow_html=True,
-            )
+        st.markdown(
+            f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
+            f'<div style="width: 20px; height: 20px; background-color: #333; '
+            f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
+            f"<span><b>Test Split</b> ({test_count} cells, {test_count / len(merged) * 100:.1f}%)</span></div>",
+            unsafe_allow_html=True,
+        )
+
+        st.markdown(
+            f'<div style="display: flex; align-items: center; margin-bottom: 12px;">'
+            f'<div style="width: 20px; height: 20px; background-color: #333; '
+            f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
+            f"<span><b>Training Split</b> ({train_count} cells, {train_count / len(merged) * 100:.1f}%)</span></div>",
+            unsafe_allow_html=True,
+        )
+
+        st.markdown("---")
+        st.markdown("**Fill Color (Prediction Results):**")
+
+        # Correct predictions (test split only)
+        test_cells = merged[merged["in_test_split"]]
+        correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0
+        incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0
+
+        st.markdown(
+            f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
+            f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
+            f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
+            f"<span><b>Correct Predictions (Test)</b> ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)</span></div>",
+            unsafe_allow_html=True,
+        )
+
+        # Incorrect predictions by predicted class (shades of red)
+        st.markdown(
+            f"<b>Incorrect Predictions by Predicted Class (Test)</b> ({incorrect} cells):", unsafe_allow_html=True
+        )
+
+        for class_idx, class_label in enumerate(ordered_classes):
+            # Get count of incorrect predictions for this predicted class (test split only)
+            count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)])
+            if count > 0:
+                # Get color for this predicted class
+                color_value = red_cmap(class_idx / max(n_classes - 1, 1))
+                rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
+
+                percentage = count / incorrect * 100 if incorrect > 0 else 0
+
+                st.markdown(
+                    f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
+                    f'<div style="width: 20px; height: 20px; background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); '
+                    f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
+                    f"<span>Predicted as <i>{class_label}</i>: {count} ({percentage:.1f}%)</span></div>",
+                    unsafe_allow_html=True,
+                )
+
+        # Note about training split
+        st.markdown(
+            f'<div style="margin-top: 8px; font-style: italic; color: #888;">'
+            f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) "
+            f"since predictions are only available for the test split.</div>",
+            unsafe_allow_html=True,
+        )
+
+        st.markdown("---")
         st.markdown("**Elevation (3D):**")
-        st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.")
 
+        # Show elevation mapping for each true class
+        st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
+        for class_idx, class_label in enumerate(ordered_classes):
+            elevation_value = (class_idx + 1) / n_classes
+            height_km = elevation_value * 500  # Since elevation_scale is 500000
+            st.markdown(
+                f'<div style="margin-left: 20px; margin-bottom: 2px;">• <i>{class_label}</i>: {height_km:.0f} km</div>',
+                unsafe_allow_html=True,
+            )
 
     st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")
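For orientation, the height math in that legend works out as below: elevations in [1/n, 1] times the layer's `elevation_scale` of 500,000 m give per-class heights in km. The labels here are the canonical count-task labels; the arithmetic mirrors `get_elevation` and the `height_km` line above.

```python
# Worked check of the per-class heights shown in the legend.
labels = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]  # count task
n = len(labels)
for i, label in enumerate(labels):
    print(label, f"{(i + 1) / n * 500:.0f} km")  # (i+1)/n * 500_000 m, printed in km
```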
+
+
+def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str):
+    """Render confusion matrix as an interactive heatmap.
+
+    Args:
+        confusion_matrix: xarray DataArray with dimensions (true_label, predicted_label).
+        task: Task type ('binary' or 'multiclass').
+
+    """
+    import plotly.express as px
+
+    # Convert to DataFrame for plotting
+    cm_df = confusion_matrix.to_pandas()
+
+    # Get labels (convert numeric labels to semantic labels if possible)
+    true_labels = confusion_matrix.coords["true_label"].values
+    pred_labels = confusion_matrix.coords["predicted_label"].values
+
+    # For binary classification, map 0/1 to No-RTS/RTS
+    if task == "binary":
+        label_map = {0: "No-RTS", 1: "RTS"}
+        true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
+        pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
+    else:
+        # For multiclass, use numeric labels as is
+        true_labels_str = [str(label) for label in true_labels]
+        pred_labels_str = [str(label) for label in pred_labels]
+
+    # Rename DataFrame indices and columns for display
+    cm_df.index = true_labels_str
+    cm_df.columns = pred_labels_str
+
+    # Store raw counts for annotations
+    cm_counts = cm_df.copy()
+
+    # Normalize by row (true label) to get percentages
+    cm_normalized = cm_df.div(cm_df.sum(axis=1), axis=0)
+
+    # Create custom text annotations showing both percentage and count
+    text_annotations = []
+    for i, true_label in enumerate(true_labels_str):
+        row_annotations = []
+        for j, pred_label in enumerate(pred_labels_str):
+            count = int(cm_counts.iloc[i, j])
+            percentage = cm_normalized.iloc[i, j] * 100
+            row_annotations.append(f"{percentage:.1f}%<br>({count:,})")
+        text_annotations.append(row_annotations)
+
+    # Create heatmap with normalized values
+    fig = px.imshow(
+        cm_normalized,
+        labels=dict(x="Predicted Label", y="True Label", color="Proportion"),
+        x=pred_labels_str,
+        y=true_labels_str,
+        color_continuous_scale="Blues",
+        aspect="auto",
+        zmin=0,
+        zmax=1,
+    )
+
+    # Update with custom annotations
+    fig.update_traces(
+        text=text_annotations,
+        texttemplate="%{text}",
+        textfont={"size": 12},
+    )
+
+    # Update layout for better readability
+    fig.update_layout(
+        title="Confusion Matrix (Normalized by True Label)",
+        xaxis_title="Predicted Label",
+        yaxis_title="True Label",
+        height=500,
+    )
+
+    # Update colorbar to show percentage
+    fig.update_coloraxes(
+        colorbar=dict(
+            title="Proportion",
+            tickformat=".0%",
+        )
+    )
+
+    st.plotly_chart(fig, width="stretch")
+
+    st.caption(
+        "📊 Values show **row-normalized percentages** (percentage of each true class predicted as each label). "
+        "Raw counts shown in parentheses."
+    )
+
+    # Calculate and display metrics from confusion matrix
+    col1, col2, col3 = st.columns(3)
+
+    total_samples = int(cm_df.values.sum())
+    correct_predictions = int(np.trace(cm_df.values))
+    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
+
+    with col1:
+        st.metric("Total Samples", f"{total_samples:,}")
+
+    with col2:
+        st.metric("Correct Predictions", f"{correct_predictions:,}")
+
+    with col3:
+        st.metric("Accuracy", f"{accuracy:.2%}")
+
+    # Add detailed breakdown for binary classification
+    if task == "binary":
+        st.markdown("#### Binary Classification Metrics")
+
+        # Extract TP, TN, FP, FN from confusion matrix
+        # Assuming 0=No-RTS (negative), 1=RTS (positive)
+        tn = int(cm_df.iloc[0, 0])
+        fp = int(cm_df.iloc[0, 1])
+        fn = int(cm_df.iloc[1, 0])
+        tp = int(cm_df.iloc[1, 1])
+
+        col1, col2, col3, col4 = st.columns(4)
+
+        with col1:
+            st.metric("True Positives (TP)", f"{tp:,}")
+
+        with col2:
+            st.metric("True Negatives (TN)", f"{tn:,}")
+
+        with col3:
+            st.metric("False Positives (FP)", f"{fp:,}")
+
+        with col4:
+            st.metric("False Negatives (FN)", f"{fn:,}")
@@ -6,6 +6,7 @@ import plotly.graph_objects as go
 import pydeck as pdk
 import streamlit as st
 
+from entropice.dashboard.utils.class_ordering import get_ordered_classes, sort_class_series
 from entropice.dashboard.utils.colors import get_palette
 from entropice.dashboard.utils.geometry import fix_hex_geometry
 from entropice.dashboard.utils.loaders import TrainingResult
@@ -64,8 +65,9 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task:
     """
     st.subheader("📊 Predicted Class Distribution")
 
-    # Get class counts
-    class_counts = predictions_gdf["predicted_class"].value_counts().sort_index()
+    # Get class counts and order them properly
+    class_counts = predictions_gdf["predicted_class"].value_counts()
+    class_counts = sort_class_series(class_counts, task)
 
     # Get colors based on task
     categories = class_counts.index.tolist()
@@ -121,7 +123,7 @@ def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
     st.subheader("🌍 Spatial Coverage")
 
     # Calculate spatial extent
-    bounds = predictions_gdf.total_bounds
+    bounds = predictions_gdf.to_crs("EPSG:4326").total_bounds
 
     col1, col2, col3, col4 = st.columns(4)
 
@@ -193,8 +195,8 @@ def render_inference_map(result: TrainingResult):
     col1, col2, col3 = st.columns([2, 2, 1])
 
     with col1:
-        # Get unique classes for filtering
-        all_classes = sorted(preds_gdf["predicted_class"].unique())
+        # Get unique classes for filtering (properly ordered)
+        all_classes = get_ordered_classes(task, preds_gdf["predicted_class"].unique().tolist())
         filter_options = ["All Classes", *all_classes]
 
         selected_filter = st.selectbox(
@@ -240,11 +242,13 @@ def render_inference_map(result: TrainingResult):
     if grid == "hex":
         display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
 
-    # Assign colors based on predicted class
-    colors_palette = get_palette(task, len(all_classes))
+    # Assign colors based on predicted class (using canonical ordering)
+    # Get all possible classes for this task to ensure consistent colors
+    canonical_classes = get_ordered_classes(task)
+    colors_palette = get_palette(task, len(canonical_classes))
 
-    # Create color mapping for all classes
-    color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)}
+    # Create color mapping for canonical classes
+    color_map = {cls: colors_palette[i] for i, cls in enumerate(canonical_classes)}
 
     # Convert hex colors to RGB
     def hex_to_rgb(hex_color):
@@ -349,8 +353,9 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
     """
     st.subheader("🔍 Class Comparison")
 
-    # Get class distribution
+    # Get class distribution and order properly
     class_counts = predictions_gdf["predicted_class"].value_counts()
+    class_counts = sort_class_series(class_counts, task)
 
     if len(class_counts) < 2:
         st.info("Need at least 2 classes for comparison.")
@@ -362,7 +367,13 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
     with col1:
         st.markdown("**Class Proportions**")
 
-        colors = get_palette(task, len(class_counts))
+        # Get colors matching canonical order
+        canonical_classes = get_ordered_classes(task)
+        all_colors = get_palette(task, len(canonical_classes))
+        color_map = {cls: all_colors[i] for i, cls in enumerate(canonical_classes)}
+
+        # Extract colors for available classes in proper order
+        colors = [color_map[cls] for cls in class_counts.index]
 
         fig = go.Figure(
             data=[
@@ -271,7 +271,7 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
     return chart
 
 
-def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart:
+def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart | None:
     """Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).
 
     This is specifically for seasonal/shoulder data to show temporal trends.
@@ -48,50 +48,78 @@ def create_sample_count_heatmap(
 
 def create_sample_count_bar_chart(
     sample_df: pd.DataFrame,
-    task_colors: list[str] | None = None,
+    target_color_maps: dict[str, list[str]] | None = None,
 ) -> go.Figure:
     """Create bar chart showing sample counts by grid, target, and task.
 
     Args:
         sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
-        task_colors: Optional color palette for tasks. If None, uses default Plotly colors.
+        target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes.
+            If None, uses default Plotly colors.
 
     Returns:
         Plotly Figure object containing the bar chart visualization.
 
     """
-    fig = px.bar(
-        sample_df,
-        x="Grid",
-        y="Samples (Coverage)",
-        color="Task",
-        facet_col="Target",
-        barmode="group",
-        title="Sample Counts by Grid Configuration and Target Dataset",
-        labels={
-            "Grid": "Grid Configuration",
-            "Samples (Coverage)": "Number of Samples",
-        },
-        color_discrete_sequence=task_colors,
-        height=500,
-    )
+    # Create subplots manually to have better control over colors
+    from plotly.subplots import make_subplots
+
+    targets = sorted(sample_df["Target"].unique())
+    tasks = sorted(sample_df["Task"].unique())
+
+    fig = make_subplots(
+        rows=1,
+        cols=len(targets),
+        subplot_titles=[f"Target: {target}" for target in targets],
+        shared_yaxes=True,
+    )
+
+    for col_idx, target in enumerate(targets, 1):
+        target_data = sample_df[sample_df["Target"] == target]
+        # Get color palette for this target
+        colors = target_color_maps.get(target, None) if target_color_maps else None
+
+        for task_idx, task in enumerate(tasks):
+            task_data = target_data[target_data["Task"] == task]
+            color = colors[task_idx] if colors and task_idx < len(colors) else None
+
+            # Create a unique legendgroup per task so colors are consistent
+            fig.add_trace(
+                go.Bar(
+                    x=task_data["Grid"],
+                    y=task_data["Samples (Coverage)"],
+                    name=task,
+                    marker_color=color,
+                    legendgroup=task,  # Group by task name
+                    showlegend=(col_idx == 1),  # Only show legend for first subplot
+                ),
+                row=1,
+                col=col_idx,
+            )
+
+    fig.update_layout(
+        title_text="Training Sample Counts by Grid Configuration and Target Dataset",
+        barmode="group",
+        height=500,
+        showlegend=True,
+    )
 
-    # Update facet labels to be cleaner
-    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
     fig.update_xaxes(tickangle=-45)
+    fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
 
     return fig
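A usage sketch for the reworked `create_sample_count_bar_chart` — the column names follow the docstring; the grid names, counts, and hex colors are invented for illustration.

```python
# Invented sample data matching the expected columns.
import pandas as pd

from entropice.dashboard.plots.overview import create_sample_count_bar_chart

df = pd.DataFrame(
    {
        "Grid": ["hex L5", "hex L5", "healpix L6", "healpix L6"],
        "Target": ["rts", "rts", "rts", "rts"],
        "Task": ["binary", "count", "binary", "count"],
        "Samples (Coverage)": [1200, 1150, 800, 790],
    }
)
fig = create_sample_count_bar_chart(df, target_color_maps={"rts": ["#1f77b4", "#ff7f0e"]})
fig.show()
```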
 
 
 def create_feature_count_stacked_bar(
     breakdown_df: pd.DataFrame,
-    source_colors: list[str] | None = None,
+    source_color_map: dict[str, str] | None = None,
 ) -> go.Figure:
     """Create stacked bar chart showing feature counts by data source.
 
     Args:
         breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features.
-        source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
+        source_color_map: Optional dictionary mapping data source names to specific colors.
+            If None, uses default Plotly colors.
 
     Returns:
         Plotly Figure object containing the stacked bar chart visualization.
@@ -103,12 +131,12 @@ def create_feature_count_stacked_bar(
         y="Number of Features",
         color="Data Source",
         barmode="stack",
-        title="Total Features by Data Source Across Grid Configurations",
+        title="Input Features by Data Source Across Grid Configurations",
         labels={
             "Grid": "Grid Configuration",
             "Number of Features": "Number of Features",
         },
-        color_discrete_sequence=source_colors,
+        color_discrete_map=source_color_map,
         text_auto=False,
     )
 
@@ -136,7 +164,7 @@ def create_inference_cells_bar(
         x="Grid",
         y="Inference Cells",
         color="Grid",
-        title="Inference Cells by Grid Configuration",
+        title="Spatial Coverage (Grid Cells with Complete Data)",
         labels={
             "Grid": "Grid Configuration",
             "Inference Cells": "Number of Cells",
@@ -170,7 +198,7 @@ def create_total_samples_bar(
         x="Grid",
         y="Total Samples",
         color="Grid",
-        title="Total Samples by Grid Configuration",
+        title="Training Samples (Binary Task)",
         labels={
             "Grid": "Grid Configuration",
             "Total Samples": "Number of Samples",
@@ -188,14 +216,15 @@ def create_total_samples_bar(
 def create_feature_breakdown_donut(
     grid_data: pd.DataFrame,
     grid_config: str,
-    source_colors: list[str] | None = None,
+    source_color_map: dict[str, str] | None = None,
 ) -> go.Figure:
     """Create donut chart for feature breakdown by data source for a specific grid.
 
     Args:
         grid_data: DataFrame with columns: Data Source, Number of Features.
         grid_config: Grid configuration name for the title.
-        source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
+        source_color_map: Optional dictionary mapping data source names to specific colors.
+            If None, uses default Plotly colors.
 
     Returns:
         Plotly Figure object containing the donut chart visualization.
@@ -207,7 +236,8 @@ def create_feature_breakdown_donut(
         values="Number of Features",
         title=grid_config,
         hole=0.4,
-        color_discrete_sequence=source_colors,
+        color_discrete_map=source_color_map,
+        color="Data Source",
     )
 
     fig.update_traces(textposition="inside", textinfo="percent")
@@ -218,13 +248,14 @@ def create_feature_breakdown_donut(
 
 def create_feature_distribution_pie(
     breakdown_df: pd.DataFrame,
-    source_colors: list[str] | None = None,
+    source_color_map: dict[str, str] | None = None,
 ) -> go.Figure:
     """Create pie chart for feature distribution by data source.
 
     Args:
         breakdown_df: DataFrame with columns: Data Source, Number of Features.
-        source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
+        source_color_map: Optional dictionary mapping data source names to specific colors.
+            If None, uses default Plotly colors.
 
     Returns:
         Plotly Figure object containing the pie chart visualization.
@@ -236,7 +267,8 @@ def create_feature_distribution_pie(
         values="Number of Features",
         title="Feature Distribution by Data Source",
         hole=0.4,
-        color_discrete_sequence=source_colors,
+        color_discrete_map=source_color_map,
+        color="Data Source",
    )
 
     fig.update_traces(textposition="inside", textinfo="percent+label")
70 src/entropice/dashboard/utils/class_ordering.py Normal file

@@ -0,0 +1,70 @@
"""Utilities for ordering predicted classes consistently across visualizations.
|
||||
|
||||
This module leverages the canonical class labels defined in the ML dataset module
|
||||
to ensure consistent ordering across all visualizations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from entropice.utils.types import Task
|
||||
|
||||
# Canonical orderings imported from the ML pipeline
|
||||
# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"}
|
||||
# Count/Density labels are defined in the bin_values function
|
||||
BINARY_LABELS = ["No RTS", "RTS"]
|
||||
COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]
|
||||
DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"]
|
||||
|
||||
CLASS_ORDERINGS: dict[Task | str, list[str]] = {
|
||||
"binary": BINARY_LABELS,
|
||||
"count": COUNT_LABELS,
|
||||
"density": DENSITY_LABELS,
|
||||
}
|
||||
|
||||
|
||||
def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]:
|
||||
"""Get properly ordered class labels for a given task.
|
||||
|
||||
This uses the same canonical ordering as defined in the ML dataset module,
|
||||
ensuring consistency between training and inference visualizations.
|
||||
|
||||
Args:
|
||||
task: Task type ('binary', 'count', 'density').
|
||||
available_classes: Optional list of available classes to filter and order.
|
||||
If None, returns all canonical classes for the task.
|
||||
|
||||
Returns:
|
||||
List of class labels in proper order.
|
||||
|
||||
Examples:
|
||||
>>> get_ordered_classes("binary")
|
||||
['No RTS', 'RTS']
|
||||
>>> get_ordered_classes("count", ["None", "Few", "Several"])
|
||||
['None', 'Few', 'Several']
|
||||
|
||||
"""
|
||||
canonical_order = CLASS_ORDERINGS[task]
|
||||
|
||||
if available_classes is None:
|
||||
return canonical_order
|
||||
|
||||
# Filter canonical order to only include available classes, preserving order
|
||||
return [cls for cls in canonical_order if cls in available_classes]
|
||||
|
||||
|
||||
def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series:
|
||||
"""Sort a pandas Series with class labels according to canonical ordering.
|
||||
|
||||
Args:
|
||||
series: Pandas Series with class labels as index.
|
||||
task: Task type ('binary', 'count', 'density').
|
||||
|
||||
Returns:
|
||||
Sorted Series with classes in canonical order.
|
||||
|
||||
"""
|
||||
available_classes = series.index.tolist()
|
||||
ordered_classes = get_ordered_classes(task, available_classes)
|
||||
|
||||
# Reindex to get proper order
|
||||
return series.reindex(ordered_classes)
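Usage sketch for the two helpers above — the counts are invented; the expected outputs follow from the docstring examples.

```python
# Ordering a value_counts()-style Series by the canonical count-task labels.
import pandas as pd

from entropice.dashboard.utils.class_ordering import get_ordered_classes, sort_class_series

counts = pd.Series({"Few": 10, "None": 25, "Several": 3})
print(get_ordered_classes("count", counts.index.tolist()))  # ['None', 'Few', 'Several']
print(sort_class_series(counts, "count"))  # reindexed: None, Few, Several
```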
@@ -1,13 +1,9 @@
 """Geometry utilities for dashboard visualizations."""
 
+import antimeridian
 import streamlit as st
 from shapely.geometry import shape
 
-try:
-    import antimeridian
-except ImportError:
-    antimeridian = None
-
 
 def fix_hex_geometry(geom):
     """Fix hexagon geometry crossing the antimeridian.
@@ -23,13 +19,9 @@ def fix_hex_geometry(geom):
         Fixed geometry object with antimeridian issues resolved.
 
     Note:
-        If the antimeridian library is not available or an error occurs,
-        returns the original geometry unchanged.
+        If an error occurs, returns the original geometry unchanged.
 
     """
-    if antimeridian is None:
-        return geom
-
     try:
         return shape(antimeridian.fix_shape(geom))
     except ValueError as e:
@@ -35,6 +35,8 @@ class TrainingResult:
     path: Path
     settings: TrainingSettings
     results: pd.DataFrame
+    metrics: dict[str, float]
+    confusion_matrix: xr.DataArray
     created_at: float
     available_metrics: list[str]
 
@@ -44,23 +46,32 @@ class TrainingResult:
         result_file = result_path / "search_results.parquet"
         preds_file = result_path / "predicted_probabilities.parquet"
         settings_file = result_path / "search_settings.toml"
+        metrics_file = result_path / "test_metrics.toml"
+        confusion_matrix_file = result_path / "confusion_matrix.nc"
         if not result_file.exists():
             raise FileNotFoundError(f"Missing results file in {result_path}")
         if not settings_file.exists():
             raise FileNotFoundError(f"Missing settings file in {result_path}")
         if not preds_file.exists():
             raise FileNotFoundError(f"Missing predictions file in {result_path}")
+        if not metrics_file.exists():
+            raise FileNotFoundError(f"Missing metrics file in {result_path}")
+        if not confusion_matrix_file.exists():
+            raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
 
         created_at = result_path.stat().st_ctime
         settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
         results = pd.read_parquet(result_file)
 
+        metrics = toml.load(metrics_file)["test_metrics"]
+        confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
         available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
 
         return cls(
             path=result_path,
             settings=settings,
             results=results,
+            metrics=metrics,
+            confusion_matrix=confusion_matrix,
             created_at=created_at,
             available_metrics=available_metrics,
         )
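For reference, loading a single result directory with the extended loader looks roughly like this — the directory name is illustrative; `from_path` raises `FileNotFoundError` if any of the five required artifacts is missing.

```python
# Illustrative use of TrainingResult.from_path with the new artifacts.
from pathlib import Path

from entropice.dashboard.utils.loaders import TrainingResult

result = TrainingResult.from_path(Path("results/hex5_binary_cv0"))  # hypothetical name
print(result.metrics)            # dict from the [test_metrics] table in test_metrics.toml
print(result.confusion_matrix)   # xarray.DataArray read from confusion_matrix.nc
print(result.available_metrics)  # parsed from the mean_test_* result columns
```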
@@ -126,7 +137,7 @@ def load_all_training_results() -> list[TrainingResult]:
         try:
             training_result = TrainingResult.from_path(result_path)
         except FileNotFoundError as e:
-            st.warning(f"Skipping incomplete training result at {result_path}: {e}")
+            st.warning(f"Skipping incomplete training result: {e}")
             continue
         training_results.append(training_result)
 
@@ -14,7 +14,6 @@ from entropice.dashboard.plots.inference import (
 from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
 
 
-@st.fragment
 def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
     """Render sidebar for training run selection.
 
@@ -13,8 +13,6 @@ from entropice.dashboard.plots.overview import (
     create_feature_distribution_pie,
     create_inference_cells_bar,
     create_sample_count_bar_chart,
-    create_sample_count_heatmap,
-    create_total_samples_bar,
 )
 from entropice.dashboard.utils.colors import get_palette
 from entropice.dashboard.utils.loaders import load_all_training_results
@@ -32,11 +30,9 @@ from entropice.utils.types import (
 
 def render_sample_count_overview():
     """Render overview of sample counts per task+target+grid+level combination."""
-    st.subheader("📊 Sample Counts by Configuration")
-
     st.markdown(
         """
-        This visualization shows the number of available samples for each combination of:
+        This visualization shows the number of available training samples for each combination of:
         - **Task**: binary, count, density
         - **Target Dataset**: darts_rts, darts_mllabels
         - **Grid System**: hex, healpix
@@ -47,68 +43,39 @@ def render_sample_count_overview():
     # Get sample count DataFrame from cache
     all_stats = load_all_default_dataset_statistics()
     sample_df = DatasetStatistics.get_sample_count_df(all_stats)
-    target_datasets = ["darts_rts", "darts_mllabels"]
 
-    # Create tabs for different views
-    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
+    # Get color palettes for each target dataset
+    n_tasks = sample_df["Task"].nunique()
+    target_color_maps = {
+        "rts": get_palette("task_types", n_colors=n_tasks),
+        "mllabels": get_palette("data_sources", n_colors=n_tasks),
+    }
 
-    with tab1:
-        st.markdown("### Sample Counts Heatmap")
-        st.markdown("Showing counts of samples with coverage")
+    # Create and display bar chart
+    fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
+    st.plotly_chart(fig, use_container_width=True)
 
-        # Create heatmap for each target dataset
-        for target in target_datasets:
-            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
+    # Display full table with formatting
+    st.markdown("#### Detailed Sample Counts")
+    display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
 
-            # Pivot for heatmap: Grid x Task
-            pivot_df = target_df.pivot_table(
-                index="Grid",
-                columns="Task",
-                values="Samples (Coverage)",
-                aggfunc="mean",
-            )
+    # Format numbers with commas
+    display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
+    # Format coverage as percentage with 2 decimal places
+    display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
 
-            # Sort index by grid type and level
-            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
-            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)
-
-            # Get color palette for sample counts
-            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
-
-            # Create and display heatmap
-            fig = create_sample_count_heatmap(pivot_df, target, colorscale=sample_colors)
-            st.plotly_chart(fig, width="stretch")
-
-    with tab2:
-        st.markdown("### Sample Counts Bar Chart")
-        st.markdown("Showing counts of samples with coverage")
-
-        # Get color palette for tasks
-        n_tasks = sample_df["Task"].nunique()
-        task_colors = get_palette("task_types", n_colors=n_tasks)
-
-        # Create and display bar chart
-        fig = create_sample_count_bar_chart(sample_df, task_colors=task_colors)
-        st.plotly_chart(fig, width="stretch")
-
-    with tab3:
-        st.markdown("### Detailed Sample Counts")
-
-        # Display full table with formatting
-        display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
-
-        # Format numbers with commas
-        display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
-        # Format coverage as percentage with 2 decimal places
-        display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
-
-        st.dataframe(display_df, hide_index=True, width="stretch")
+    st.dataframe(display_df, hide_index=True, use_container_width=True)
 
 
 def render_feature_count_comparison():
     """Render static comparison of feature counts across all grid configurations."""
-    st.markdown("### Feature Count Comparison Across Grid Configurations")
-    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
+    st.markdown(
+        """
+        Comparing dataset characteristics for all grid configurations with all data sources enabled.
+        - **Features**: Total number of input features from all data sources
+        - **Spatial Coverage**: Number of grid cells with complete data coverage
+        """
+    )
 
     # Get data from cache
     all_stats = load_all_default_dataset_statistics()
@@ -116,87 +83,44 @@ def render_feature_count_comparison():
     breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
     breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
 
-    # Create tabs for different comparison views
-    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
+    # Get all unique data sources and create color map
+    unique_sources = sorted(breakdown_df["Data Source"].unique())
+    n_sources = len(unique_sources)
+    source_color_list = get_palette("data_sources", n_colors=n_sources)
+    source_color_map = dict(zip(unique_sources, source_color_list))
 
-    with comp_tab1:
-        st.markdown("#### Total Features by Grid Configuration")
+    # Create and display stacked bar chart
+    fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
+    st.plotly_chart(fig, use_container_width=True)
 
-        # Get color palette for data sources
-        unique_sources = breakdown_df["Data Source"].unique()
-        n_sources = len(unique_sources)
-        source_colors = get_palette("data_sources", n_colors=n_sources)
+    # Add spatial coverage metric
+    n_grids = len(comparison_df)
+    grid_colors = get_palette("grid_configs", n_colors=n_grids)
 
-        # Create and display stacked bar chart
-        fig = create_feature_count_stacked_bar(breakdown_df, source_colors=source_colors)
-        st.plotly_chart(fig, width="stretch")
+    fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
+    st.plotly_chart(fig_cells, use_container_width=True)
 
-        # Add secondary metrics
-        col1, col2 = st.columns(2)
-
-        # Get color palette for grid configs
-        n_grids = len(comparison_df)
-        grid_colors = get_palette("grid_configs", n_colors=n_grids)
-
-        with col1:
-            fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
-            st.plotly_chart(fig_cells, width="stretch")
-
-        with col2:
-            fig_samples = create_total_samples_bar(comparison_df, grid_colors=grid_colors)
-            st.plotly_chart(fig_samples, width="stretch")
-
-    with comp_tab2:
-        st.markdown("#### Feature Breakdown by Data Source")
-        st.markdown("Showing percentage contribution of each data source across all grid configurations")
-
-        # Get color palette for data sources
-        unique_sources = breakdown_df["Data Source"].unique()
-        n_sources = len(unique_sources)
-        source_colors = get_palette("data_sources", n_colors=n_sources)
-
-        # Create donut charts for each grid configuration
-        # Organize in a grid layout
-        num_grids = len(comparison_df)
-        cols_per_row = 3
-        num_rows = (num_grids + cols_per_row - 1) // cols_per_row
-
-        for row_idx in range(num_rows):
-            cols = st.columns(cols_per_row)
-            for col_idx in range(cols_per_row):
-                grid_idx = row_idx * cols_per_row + col_idx
-                if grid_idx < num_grids:
-                    grid_config = comparison_df.iloc[grid_idx]["Grid"]
-                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
-
-                    with cols[col_idx]:
-                        fig = create_feature_breakdown_donut(grid_data, grid_config, source_colors=source_colors)
-                        st.plotly_chart(fig, width="stretch")
-
-    with comp_tab3:
-        st.markdown("#### Detailed Feature Count Comparison")
-
-        # Display full comparison table with formatting
-        display_df = comparison_df[
-            [
-                "Grid",
-                "Total Features",
-                "Data Sources",
-                "Inference Cells",
-                "Total Samples",
-            ]
-        ].copy()
-
-        # Format numbers with commas
-        for col in ["Total Features", "Inference Cells", "Total Samples"]:
-            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
-
-        st.dataframe(display_df, hide_index=True, width="stretch")
+    # Display full comparison table with formatting
+    st.markdown("#### Detailed Comparison Table")
+    display_df = comparison_df[
+        [
+            "Grid",
+            "Total Features",
+            "Data Sources",
+            "Inference Cells",
+        ]
+    ].copy()
+
+    # Format numbers with commas
+    for col in ["Total Features", "Inference Cells"]:
+        display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
+
+    st.dataframe(display_df, hide_index=True, use_container_width=True)
 
 
 @st.fragment
 def render_feature_count_explorer():
     """Render interactive detailed configuration explorer using fragments."""
     st.markdown("### Detailed Configuration Explorer")
     st.markdown("Select specific grid configuration and data sources for detailed statistics")
 
     # Grid selection
@@ -250,8 +174,6 @@ def render_feature_count_explorer():
 
     # Show results if at least one member is selected
     if selected_members:
-        st.markdown("---")
-
         # Get statistics from cache (already loaded)
         grid_stats = all_stats[selected_grid_config.id]
 
@@ -301,12 +223,14 @@ def render_feature_count_explorer():
 
         breakdown_df = pd.DataFrame(breakdown_data)
 
-        # Get color palette for data sources
-        n_sources = len(breakdown_df)
-        source_colors = get_palette("data_sources", n_colors=n_sources)
+        # Get all unique data sources and create color map
+        unique_sources = sorted(breakdown_df["Data Source"].unique())
+        n_sources = len(unique_sources)
+        source_color_list = get_palette("data_sources", n_colors=n_sources)
+        source_color_map = dict(zip(unique_sources, source_color_list))
 
         # Create and display pie chart
-        fig = create_feature_distribution_pie(breakdown_df, source_colors=source_colors)
+        fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
         st.plotly_chart(fig, width="stretch")
 
         # Show detailed table
@@ -363,38 +287,82 @@ def render_feature_count_explorer():
         st.info("👆 Select at least one data source to see feature statistics")
 
 
-def render_feature_count_section():
-    """Render the feature count section with comparison and explorer."""
-    st.subheader("🔢 Feature Counts by Dataset Configuration")
-
-    st.markdown(
-        """
-        This visualization shows the total number of features that would be generated
-        for different combinations of data sources and grid configurations.
-        """
-    )
-
-    # Static comparison across all grids
-    render_feature_count_comparison()
-
-    st.divider()
-
-    # Interactive explorer for detailed analysis
-    render_feature_count_explorer()
-
-
 def render_dataset_analysis():
     """Render the dataset analysis section with sample and feature counts."""
     st.header("📈 Dataset Analysis")
 
-    # Create tabs for the two different analyses
-    analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
+    # Create tabs for different analysis views
+    analysis_tabs = st.tabs(
+        [
+            "📊 Training Samples",
+            "📈 Dataset Characteristics",
+            "🔍 Feature Breakdown",
+            "⚙️ Configuration Explorer",
+        ]
+    )
 
     with analysis_tabs[0]:
+        st.subheader("Training Samples by Configuration")
         render_sample_count_overview()
 
     with analysis_tabs[1]:
-        render_feature_count_section()
+        st.subheader("Dataset Characteristics Across Grid Configurations")
+        render_feature_count_comparison()
+
+    with analysis_tabs[2]:
+        st.subheader("Feature Breakdown by Data Source")
+        # Get data from cache
+        all_stats = load_all_default_dataset_statistics()
+        comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
+        breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
+        breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
+
+        # Get all unique data sources and create color map
+        unique_sources = sorted(breakdown_df["Data Source"].unique())
+        n_sources = len(unique_sources)
+        source_color_list = get_palette("data_sources", n_colors=n_sources)
+        source_color_map = dict(zip(unique_sources, source_color_list))
+
+        st.markdown("Showing percentage contribution of each data source across all grid configurations")
+
+        # Render donut charts per resolution, hex and healpix side by side
+        for res in ["sparse", "low", "medium"]:
+            cols = st.columns(2)
+            with cols[0]:
+                grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"]
+                for gc in grid_configs_res:
+                    grid_display = gc.display_name
+                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
+                    fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
+                    st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
+            with cols[1]:
+                grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"]
+                for gc in grid_configs_res:
+                    grid_display = gc.display_name
+                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
+                    fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
+                    st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
+
+        # Create donut charts for each grid configuration
+        # num_grids = len(comparison_df)
+        # cols_per_row = 3
+        # num_rows = (num_grids + cols_per_row - 1) // cols_per_row
+
+        # for row_idx in range(num_rows):
+        #     cols = st.columns(cols_per_row)
+        #     for col_idx in range(cols_per_row):
+        #         grid_idx = row_idx * cols_per_row + col_idx
+        #         if grid_idx < num_grids:
+        #             grid_config = comparison_df.iloc[grid_idx]["Grid"]
+        #             grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
+
+        #             with cols[col_idx]:
+        #                 fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map)
+        #                 st.plotly_chart(fig, use_container_width=True)
+
+    with analysis_tabs[3]:
+        st.subheader("Interactive Configuration Explorer")
+        render_feature_count_explorer()
 
 
 def render_training_results_summary(training_results):
@@ -1,10 +1,13 @@
 """Training Results Analysis page: Analysis of training results and model performance."""

+from typing import cast
+
 import streamlit as st
 from stopuhr import stopwatch

 from entropice.dashboard.plots.hyperparameter_analysis import (
     render_binned_parameter_space,
+    render_confusion_matrix_heatmap,
+    render_confusion_matrix_map,
     render_espa_binned_parameter_space,
     render_multi_metric_comparison,
@@ -14,12 +17,12 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
     render_top_configurations,
 )
 from entropice.dashboard.utils.formatters import format_metric_name
-from entropice.dashboard.utils.loaders import load_all_training_results
+from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
 from entropice.dashboard.utils.stats import CVResultsStatistics
+from entropice.utils.types import GridConfig


 @st.fragment
-def render_analysis_settings_sidebar(training_results):
+def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
     """Render sidebar for training run and analysis settings selection.

     Args:
@@ -42,7 +45,7 @@ def render_analysis_settings_sidebar(training_results):
         key="training_run_select",
     )

-    selected_result = training_options[selected_name]
+    selected_result = cast(TrainingResult, training_options[selected_name])

     st.divider()
@@ -52,22 +55,7 @@ def render_analysis_settings_sidebar(training_results):
     available_metrics = selected_result.available_metrics

-    # Try to get refit metric from settings
-    refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
-
-    if not refit_metric or refit_metric not in available_metrics:
-        # Infer from task or use first available metric
-        task = selected_result.settings.task
-        if task == "binary" and "f1" in available_metrics:
-            refit_metric = "f1"
-        elif "f1_weighted" in available_metrics:
-            refit_metric = "f1_weighted"
-        elif "accuracy" in available_metrics:
-            refit_metric = "accuracy"
-        elif available_metrics:
-            refit_metric = available_metrics[0]
-        else:
-            st.error("No metrics found in results.")
-            return None, None, None, None
+    refit_metric = "f1" if selected_result.settings.task == "binary" else "f1_weighted"

     if refit_metric in available_metrics:
         default_metric_idx = available_metrics.index(refit_metric)
@@ -97,7 +85,7 @@ def render_analysis_settings_sidebar(training_results):
     return selected_result, selected_metric, refit_metric, top_n


-def render_run_information(selected_result, refit_metric):
+def render_run_information(selected_result: TrainingResult, refit_metric):
     """Render training run configuration overview.

     Args:
@@ -107,23 +95,86 @@ def render_run_information(selected_result, refit_metric):
     """
     st.header("📋 Run Information")

-    col1, col2, col3, col4, col5, col6 = st.columns(6)
+    grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}")  # ty:ignore[invalid-argument-type]
+
+    col1, col2, col3, col4, col5 = st.columns(5)
     with col1:
         st.metric("Task", selected_result.settings.task.capitalize())
     with col2:
-        st.metric("Grid", selected_result.settings.grid.capitalize())
+        st.metric("Target", selected_result.settings.target.capitalize())
     with col3:
-        st.metric("Level", selected_result.settings.level)
+        st.metric("Grid", grid_config.display_name)
     with col4:
         st.metric("Model", selected_result.settings.model.upper())
     with col5:
         st.metric("Trials", len(selected_result.results))
-    with col6:
-        st.metric("CV Splits", selected_result.settings.cv_splits)

     st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")


+def render_test_metrics_section(selected_result: TrainingResult):
+    """Render test metrics overview showing final model performance.
+
+    Args:
+        selected_result: The selected TrainingResult object.
+
+    """
+    st.header("🎯 Test Set Performance")
+    st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)")
+
+    test_metrics = selected_result.metrics
+
+    if not test_metrics:
+        st.warning("No test metrics available for this training run.")
+        return
+
+    # Display metrics in columns based on task type
+    task = selected_result.settings.task
+
+    if task == "binary":
+        # Binary classification metrics
+        col1, col2, col3, col4, col5 = st.columns(5)
+
+        with col1:
+            st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
+        with col2:
+            st.metric("F1 Score", f"{test_metrics.get('f1', 0):.4f}")
+        with col3:
+            st.metric("Precision", f"{test_metrics.get('precision', 0):.4f}")
+        with col4:
+            st.metric("Recall", f"{test_metrics.get('recall', 0):.4f}")
+        with col5:
+            st.metric("Jaccard", f"{test_metrics.get('jaccard', 0):.4f}")
+    else:
+        # Multiclass metrics
+        col1, col2, col3 = st.columns(3)
+
+        with col1:
+            st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
+        with col2:
+            st.metric("F1 (Macro)", f"{test_metrics.get('f1_macro', 0):.4f}")
+        with col3:
+            st.metric("F1 (Weighted)", f"{test_metrics.get('f1_weighted', 0):.4f}")
+
+        col4, col5, col6 = st.columns(3)
+
+        with col4:
+            st.metric("Precision (Macro)", f"{test_metrics.get('precision_macro', 0):.4f}")
+        with col5:
+            st.metric("Precision (Weighted)", f"{test_metrics.get('precision_weighted', 0):.4f}")
+        with col6:
+            st.metric("Recall (Macro)", f"{test_metrics.get('recall_macro', 0):.4f}")
+
+        col7, col8, col9 = st.columns(3)
+
+        with col7:
+            st.metric("Jaccard (Micro)", f"{test_metrics.get('jaccard_micro', 0):.4f}")
+        with col8:
+            st.metric("Jaccard (Macro)", f"{test_metrics.get('jaccard_macro', 0):.4f}")
+        with col9:
+            st.metric("Jaccard (Weighted)", f"{test_metrics.get('jaccard_weighted', 0):.4f}")
+
+
 def render_cv_statistics_section(selected_result, selected_metric):
     """Render cross-validation statistics for selected metric.
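The section above reads selected_result.metrics as a flat mapping from metric name to score. A hedged sketch of the expected shape for a binary run, with invented values:

# Hypothetical contents of selected_result.metrics (values invented):
test_metrics = {"accuracy": 0.91, "f1": 0.78, "precision": 0.81, "recall": 0.75, "jaccard": 0.64}
# Note: test_metrics.get(name, 0) means a metric missing from this mapping
# is rendered as 0.0000 in the dashboard rather than raising.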
@@ -133,6 +184,7 @@ def render_cv_statistics_section(selected_result, selected_metric):

     """
     st.header("📈 Cross-Validation Statistics")
+    st.caption("Performance during hyperparameter search (averaged across CV folds)")

     from entropice.dashboard.utils.stats import CVMetricStatistics
@@ -158,6 +210,45 @@ def render_cv_statistics_section(selected_result, selected_metric):
     if cv_stats.mean_cv_std is not None:
         st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")

+    # Compare with test metric if available
+    if selected_metric in selected_result.metrics:
+        test_score = selected_result.metrics[selected_metric]
+        st.divider()
+        st.subheader("CV vs Test Performance")
+
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            st.metric("Best CV Score", f"{cv_stats.best_score:.4f}")
+        with col2:
+            st.metric("Test Score", f"{test_score:.4f}")
+        with col3:
+            delta = test_score - cv_stats.best_score
+            delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0
+            st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%")
+
+        if abs(delta) > cv_stats.std_score:
+            st.warning(
+                "⚠️ Test performance differs significantly from CV performance. "
+                "This may indicate overfitting or data distribution mismatch."
+            )
+
+
+def render_confusion_matrix_section(selected_result: TrainingResult):
+    """Render confusion matrix visualization and analysis.
+
+    Args:
+        selected_result: The selected TrainingResult object.
+
+    """
+    st.header("🎲 Confusion Matrix")
+    st.caption("Detailed breakdown of predictions on the test set")
+
+    if selected_result.confusion_matrix is None:
+        st.warning("No confusion matrix available for this training run.")
+        return
+
+    render_confusion_matrix_heatmap(selected_result.confusion_matrix, selected_result.settings.task)
+
+
 def render_parameter_space_section(selected_result, selected_metric):
     """Render parameter space analysis section.
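A quick worked example of the CV-vs-test delta computed above, with invented numbers:

# Worked example (numbers invented):
best_cv_score = 0.85
test_score = 0.82
delta = test_score - best_cv_score       # -0.03
delta_pct = delta / best_cv_score * 100  # about -3.53%
# With cv_stats.std_score == 0.01, abs(delta) > std_score holds, so the
# overfitting/distribution-mismatch warning would fire.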
@@ -292,8 +383,19 @@ def render_training_analysis_page():

     st.divider()

+    # Test Metrics Section
+    render_test_metrics_section(selected_result)
+
+    st.divider()
+
+    # Confusion Matrix Section
+    render_confusion_matrix_section(selected_result)
+
+    st.divider()
+
     # Performance Summary Section
-    st.header("📊 Performance Overview")
+    st.header("📊 CV Performance Overview")
+    st.caption("Summary of hyperparameter search results across all configurations")
     render_performance_summary(results, refit_metric)

     st.divider()
@@ -320,6 +320,7 @@ def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
     render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)


+@st.fragment
 def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
     """Render ERA5 climate data analysis.
@@ -17,6 +17,7 @@ import json
 from collections.abc import Generator
 from dataclasses import asdict, dataclass, field
 from functools import cached_property
+from itertools import product
 from typing import Literal, TypedDict

 import cupy as cp
@@ -30,6 +31,7 @@ import xarray as xr
 from rich import pretty, traceback
 from sklearn import set_config
 from sklearn.model_selection import train_test_split
+from stopuhr import stopwatch

 import entropice.utils.paths
 from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
@@ -295,6 +297,30 @@ class DatasetEnsemble:
         temporal: Literal["yearly", "seasonal", "shoulder"],
     ) -> pd.DataFrame:
         era5 = self._read_member("ERA5-" + temporal, targets)
+
+        if len(era5["cell_ids"]) == 0:
+            # No data for these cells - create empty DataFrame with expected columns
+            # Use the Dataset metadata to determine column structure
+            variables = list(era5.data_vars)
+            times = era5.coords["time"].to_numpy()
+            time_df = pd.DataFrame({"time": times})
+            time_df.index = pd.DatetimeIndex(times)
+            tempus = _get_era5_tempus(time_df, temporal)
+            unique_tempus = tempus.unique()
+
+            if "aggregations" in era5.dims:
+                aggs_list = era5.coords["aggregations"].to_numpy()
+                expected_cols = [
+                    f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list)
+                ]
+            else:
+                expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)]
+
+            return pd.DataFrame(
+                index=targets["cell_id"].values,
+                columns=expected_cols,
+                dtype=float,
+            )
         era5_df = era5.to_dataframe()
         era5_df["t"] = _get_era5_tempus(era5_df, temporal)
         if "aggregations" not in era5.dims:
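For illustration, the product(...) construction above yields one column per (variable, tempus, aggregation) combination. A sketch with invented variable and aggregation names:

# Hypothetical illustration of the expected-column construction:
from itertools import product

variables = ["t2m", "tp"]         # invented ERA5 variable names
unique_tempus = ["2020", "2021"]  # e.g. yearly tempus labels
aggs_list = ["mean", "std"]
expected_cols = [f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list)]
# ['era5_t2m_2020_mean', 'era5_t2m_2020_std', 'era5_t2m_2021_mean', ...]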
@@ -303,20 +329,50 @@ class DatasetEnsemble:
         else:
             era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
             era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
+        # Ensure all target cell_ids are present, fill missing with NaN
+        era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan)
         return era5_df

     def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
+
+        if len(embeddings["cell_ids"]) == 0:
+            # No data for these cells - create empty DataFrame with expected columns
+            # Use the Dataset metadata to determine column structure
+            years = embeddings.coords["year"].to_numpy()
+            aggs = embeddings.coords["agg"].to_numpy()
+            bands = embeddings.coords["band"].to_numpy()
+            expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)]
+            return pd.DataFrame(
+                index=targets["cell_id"].values,
+                columns=expected_cols,
+                dtype=float,
+            )
         embeddings_df = embeddings.to_dataframe(name="value")
         embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
         embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
+        # Ensure all target cell_ids are present, fill missing with NaN
+        embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan)
         return embeddings_df

     def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         arcticdem = self._read_member("ArcticDEM", targets)
+
+        if len(arcticdem["cell_ids"]) == 0:
+            # No data for these cells - create empty DataFrame with expected columns
+            # Use the Dataset metadata to determine column structure
+            variables = list(arcticdem.data_vars)
+            aggs = arcticdem.coords["aggregations"].to_numpy()
+            expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
+            return pd.DataFrame(
+                index=targets["cell_id"].values,
+                columns=expected_cols,
+                dtype=float,
+            )
         arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
         arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
+        # Ensure all target cell_ids are present, fill missing with NaN
+        arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan)
         return arcticdem_df

     def get_stats(self) -> DatasetStats:
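The reindex calls added above are what guarantee every batch has the full set of target rows, regardless of member coverage. A minimal pandas sketch of that guarantee, with invented cell ids:

# Every requested cell_id appears in the output, with NaN rows for cells the
# member dataset does not cover, so later joins never drop target cells.
import numpy as np
import pandas as pd

member_df = pd.DataFrame({"feat": [1.0, 2.0]}, index=["cell_a", "cell_b"])
target_ids = ["cell_a", "cell_b", "cell_c"]
aligned = member_df.reindex(target_ids, fill_value=np.nan)
# "cell_c" is now present with feat == NaN.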
@@ -374,33 +430,39 @@ class DatasetEnsemble:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
         if cache_mode == "r" and cache_file.exists():
-            dataset = gpd.read_parquet(cache_file)
+            with stopwatch("Loading dataset from cache"):
+                dataset = gpd.read_parquet(cache_file)
+            print(
+                f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
+                f" and {len(dataset.columns)} features."
+            )
             return dataset
-        targets = self._read_target()
-        if filter_target_col is not None:
-            targets = targets.loc[targets[filter_target_col]]
-
-        member_dfs = []
-        for member in self.members:
-            if member.startswith("ERA5"):
-                era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1]  # ty:ignore[invalid-assignment]
-                member_dfs.append(self._prep_era5(targets, era5_agg))
-            elif member == "AlphaEarth":
-                member_dfs.append(self._prep_embeddings(targets))
-            elif member == "ArcticDEM":
-                member_dfs.append(self._prep_arcticdem(targets))
-            else:
-                raise NotImplementedError(f"Member {member} not implemented.")
-
-        dataset = targets.set_index("cell_id").join(member_dfs)
-        print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
+        with stopwatch("Reading target"):
+            targets = self._read_target()
+            if filter_target_col is not None:
+                targets = targets.loc[targets[filter_target_col]]
+            print(f"Read and filtered target dataset. ({len(targets)} samples)")
+        with stopwatch("Preparing member datasets"):
+            member_dfs = []
+            for member in self.members:
+                if member.startswith("ERA5"):
+                    era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1]  # ty:ignore[invalid-assignment]
+                    member_dfs.append(self._prep_era5(targets, era5_agg))
+                elif member == "AlphaEarth":
+                    member_dfs.append(self._prep_embeddings(targets))
+                elif member == "ArcticDEM":
+                    member_dfs.append(self._prep_arcticdem(targets))
+                else:
+                    raise NotImplementedError(f"Member {member} not implemented.")
+            print("Prepared all member datasets. Joining...")
+        with stopwatch("Joining datasets"):
+            dataset = targets.set_index("cell_id").join(member_dfs)
+            print("Joining complete.")

         if cache_mode in ["o", "r"]:
-            dataset.to_parquet(cache_file)
+            with stopwatch("Saving dataset to cache"):
+                dataset.to_parquet(cache_file)
+            print(f"Saved dataset to cache at {cache_file}.")
         return dataset
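A sketch of the cache_mode semantics documented in the comment above ("n": no cache, "o": overwrite cache, "r": read cache if it exists, else build and write). The constructor arguments here are illustrative only, mirroring the test fixtures:

ensemble = DatasetEnsemble(
    grid="hex", level=5, target="darts_rts", members=["ArcticDEM"],
    add_lonlat=True, filter_target=False,
)
df_fresh = ensemble.create(cache_mode="n")   # always rebuild, never touch the cache
df_cached = ensemble.create(cache_mode="r")  # reuse the parquet cache when present, write it otherwise
df_forced = ensemble.create(cache_mode="o")  # rebuild and overwrite the cache file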
@@ -46,6 +46,11 @@ def predict_proba(
         else:
             cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
         X_batch = batch.drop(columns=cols_to_drop).dropna()
+
+        # Skip empty batches (all rows had NaN values)
+        if len(X_batch) == 0:
+            continue
+
         cell_ids = X_batch.index.to_numpy()
         cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
         X_batch = X_batch.to_numpy(dtype="float64")
@@ -2,12 +2,14 @@

 import pickle
 from dataclasses import asdict, dataclass
+from functools import partial

 import cupy as cp
 import cyclopts
 import pandas as pd
 import toml
 import xarray as xr
+from array_api_compat import get_namespace
 from cuml.ensemble import RandomForestClassifier
 from cuml.neighbors import KNeighborsClassifier
 from entropy import ESPAClassifier
@@ -15,6 +17,14 @@ from rich import pretty, traceback
 from scipy.stats import loguniform, randint
 from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
 from sklearn import set_config
+from sklearn.metrics import (
+    accuracy_score,
+    confusion_matrix,
+    f1_score,
+    jaccard_score,
+    precision_score,
+    recall_score,
+)
 from sklearn.model_selection import KFold, RandomizedSearchCV
 from stopuhr import stopwatch
 from xgboost.sklearn import XGBClassifier
@@ -49,6 +59,26 @@ _metrics = {
 }


+# Metric functions applied once to the test predictions (y_test, y_pred), so no
+# metric needs to re-run prediction. functools.partial pre-binds the averaging
+# strategy for the multiclass variants.
+_metric_functions = {
+    "accuracy": accuracy_score,
+    "recall": recall_score,
+    "precision": precision_score,
+    "f1": f1_score,
+    "jaccard": jaccard_score,
+    "recall_macro": partial(recall_score, average="macro"),
+    "recall_weighted": partial(recall_score, average="weighted"),
+    "precision_macro": partial(precision_score, average="macro"),
+    "precision_weighted": partial(precision_score, average="weighted"),
+    "f1_macro": partial(f1_score, average="macro"),
+    "f1_weighted": partial(f1_score, average="weighted"),
+    "jaccard_micro": partial(jaccard_score, average="micro"),
+    "jaccard_macro": partial(jaccard_score, average="macro"),
+    "jaccard_weighted": partial(jaccard_score, average="weighted"),
+}
+
+
 @cyclopts.Parameter("*")
 @dataclass(frozen=True, kw_only=True)
 class CVSettings:
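A minimal demonstration that a partial-bound entry in the table above is equivalent to passing the keyword explicitly (labels invented):

from functools import partial
from sklearn.metrics import f1_score

y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]
f1_macro = partial(f1_score, average="macro")
assert f1_macro(y_true, y_pred) == f1_score(y_true, y_pred, average="macro")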
@@ -200,6 +230,26 @@ def random_cv(
     print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
     print(f"Accuracy on test set: {test_accuracy:.3f}")

+    # Compute predictions on the test set
+    y_pred = best_estimator.predict(training_data.X.test)
+    labels = list(range(len(training_data.y.labels)))
+    xp = get_namespace(y_test)
+    y_test = xp.as_array(y_test)
+    y_pred = xp.as_array(y_pred)
+    labels = xp.as_array(labels)
+
+    test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
+
+    # Get a confusion matrix
+    cm = confusion_matrix(y_test, y_pred, labels=labels)
+    label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
+    cm = xr.DataArray(
+        xp.as_numpy(cm),
+        dims=["true_label", "predicted_label"],
+        coords={"true_label": label_names, "predicted_label": label_names},
+        name="confusion_matrix",
+    )
+
     results_dir = get_cv_results_dir(
         "random_search",
         grid=dataset_ensemble.grid,
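Because the confusion matrix is stored as a labeled xarray DataArray, downstream analysis can work with named dimensions instead of raw indices. A sketch of consuming it, using the file name stored below:

import xarray as xr

cm = xr.open_dataarray("confusion_matrix.nc", engine="h5netcdf")
# Row-normalize: each true_label row then sums to 1, putting per-class recall
# on the diagonal.
cm_norm = cm / cm.sum(dim="predicted_label")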
@@ -239,6 +289,17 @@ def random_cv(
     print(f"Storing CV results to {results_file}")
     results.to_parquet(results_file)

+    # Store the test metrics
+    test_metrics_file = results_dir / "test_metrics.toml"
+    print(f"Storing test metrics to {test_metrics_file}")
+    with open(test_metrics_file, "w") as f:
+        toml.dump({"test_metrics": test_metrics}, f)
+
+    # Store the confusion matrix
+    cm_file = results_dir / "confusion_matrix.nc"
+    print(f"Storing confusion matrix to {cm_file}")
+    cm.to_netcdf(cm_file, engine="h5netcdf")
+
     # Get the inner state of the best estimator
     features = training_data.X.data.columns.tolist()
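The dump above nests all scores under a single [test_metrics] table, so reading them back is a plain dictionary lookup:

import toml

data = toml.load("test_metrics.toml")
accuracy = data["test_metrics"]["accuracy"]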
@@ -30,6 +30,7 @@ class GridConfig:
     level: int
     id: GridLevel
     display_name: str
+    res: Literal["sparse", "low", "medium"]
     sort_key: str

     @classmethod
@@ -46,11 +47,24 @@ class GridConfig:

         display_name = f"{grid.capitalize()}-{level}"

+        resmap: dict[str, Literal["sparse", "low", "medium"]] = {
+            "hex3": "sparse",
+            "hex4": "sparse",
+            "hex5": "low",
+            "hex6": "medium",
+            "healpix6": "sparse",
+            "healpix7": "sparse",
+            "healpix8": "low",
+            "healpix9": "low",
+            "healpix10": "medium",
+        }
+
         return cls(
             grid=grid,
             level=level,
             id=grid_level,
             display_name=display_name,
+            res=resmap[grid_level],
             sort_key=f"{grid}_{level:02d}",
         )
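A quick sketch of what the classmethod now yields, following the field definitions and the resolution mapping above (assuming from_grid_level parses the combined grid-level id as the dashboard does):

cfg = GridConfig.from_grid_level("hex5")
assert cfg.res == "low"
assert cfg.display_name == "Hex-5"
assert cfg.sort_key == "hex_05"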
33 tests/debug_arcticdem_batch.py Normal file
@@ -0,0 +1,33 @@
"""Debug script to check what _prep_arcticdem returns for a batch."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=["ArcticDEM"],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
# Get targets
|
||||
targets = ensemble._read_target()
|
||||
print(f"Total targets: {len(targets)}")
|
||||
|
||||
# Get first batch of targets
|
||||
batch_targets = targets.iloc[:100]
|
||||
print(f"\nBatch targets: {len(batch_targets)}")
|
||||
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
|
||||
|
||||
# Try to prep ArcticDEM for this batch
|
||||
print("\n" + "=" * 80)
|
||||
print("Calling _prep_arcticdem...")
|
||||
print("=" * 80)
|
||||
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
|
||||
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
|
||||
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
|
||||
print(
|
||||
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
|
||||
)
|
||||
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
|
||||
72 tests/debug_feature_mismatch.py Normal file
@@ -0,0 +1,72 @@
"""Debug script to identify feature mismatch between training and inference."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
# Test with level 6 (the actual level used in production)
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=[
|
||||
"AlphaEarth",
|
||||
"ArcticDEM",
|
||||
"ERA5-yearly",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-shoulder",
|
||||
],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
print("Creating training dataset...")
|
||||
print("=" * 80)
|
||||
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_features = set(training_data.X.data.columns)
|
||||
print(f"\nTraining dataset created with {len(training_features)} features")
|
||||
print(f"Sample features: {sorted(list(training_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Creating inference batch...")
|
||||
print("=" * 80)
|
||||
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
# for batch in batch_generator:
|
||||
if batch is None:
|
||||
print("ERROR: No batch created!")
|
||||
else:
|
||||
print(f"\nBatch created with {len(batch.columns)} columns")
|
||||
print(f"Batch columns: {sorted(batch.columns)[:15]}")
|
||||
|
||||
# Simulate the column dropping in predict_proba (inference.py)
|
||||
cols_to_drop = ["geometry"]
|
||||
if ensemble.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
print(f"\nColumns to drop: {cols_to_drop}")
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_features = set(inference_batch.columns)
|
||||
|
||||
print(f"\nInference batch after dropping has {len(inference_features)} features")
|
||||
print(f"Sample features: {sorted(list(inference_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("COMPARISON")
|
||||
print("=" * 80)
|
||||
print(f"Training features: {len(training_features)}")
|
||||
print(f"Inference features: {len(inference_features)}")
|
||||
|
||||
if training_features == inference_features:
|
||||
print("\n✅ SUCCESS: Features match perfectly!")
|
||||
else:
|
||||
print("\n❌ MISMATCH DETECTED!")
|
||||
only_in_training = training_features - inference_features
|
||||
only_in_inference = inference_features - training_features
|
||||
|
||||
if only_in_training:
|
||||
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
|
||||
if only_in_inference:
|
||||
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
|
||||
310 tests/test_dataset.py Normal file
@@ -0,0 +1,310 @@
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
||||
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ensemble():
|
||||
"""Create a sample DatasetEnsemble for testing with minimal data."""
|
||||
return DatasetEnsemble(
|
||||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
target="darts_rts",
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
filter_target="darts_has_coverage", # Filter to reduce dataset size
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ensemble_mllabels():
|
||||
"""Create a sample DatasetEnsemble with mllabels target."""
|
||||
return DatasetEnsemble(
|
||||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
target="darts_mllabels",
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    def test_initialization(self, sample_ensemble):
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert sample_ensemble.target == "darts_rts"
        assert "AlphaEarth" in sample_ensemble.members
        assert sample_ensemble.add_lonlat is True

    def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that covcol property returns correct column name."""
        assert sample_ensemble.covcol == "darts_has_coverage"
        assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"

    def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that taskcol returns correct column name for different tasks."""
        assert sample_ensemble.taskcol("binary") == "darts_has_rts"
        assert sample_ensemble.taskcol("count") == "darts_rts_count"
        assert sample_ensemble.taskcol("density") == "darts_rts_density"

        assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
        assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
        assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"

    def test_create_returns_geodataframe(self, sample_ensemble):
        """Test that create() returns a GeoDataFrame."""
        dataset = sample_ensemble.create(cache_mode="n")
        assert isinstance(dataset, gpd.GeoDataFrame)

    def test_create_has_expected_columns(self, sample_ensemble):
        """Test that create() returns dataset with expected columns."""
        dataset = sample_ensemble.create(cache_mode="n")

        # Should have geometry column
        assert "geometry" in dataset.columns

        # Should have lat/lon if add_lonlat is True
        if sample_ensemble.add_lonlat:
            assert "lon" in dataset.columns
            assert "lat" in dataset.columns

        # Should have target columns (darts_*)
        assert any(col.startswith("darts_") for col in dataset.columns)

        # Should have member data columns
        assert len(dataset.columns) > 3  # More than just geometry, lat, lon

    def test_create_batches_consistency(self, sample_ensemble):
        """Test that create_batches produces batches with consistent columns."""
        batch_size = 100
        batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # All batches should have the same columns
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            assert set(batch.columns) == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
            )

    def test_create_vs_create_batches_columns(self, sample_ensemble):
        """Test that create() and create_batches() return datasets with the same columns."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Columns should be identical
        full_cols = set(full_dataset.columns)
        batch_cols = set(batches[0].columns)

        assert full_cols == batch_cols, (
            f"Column mismatch between create() and create_batches().\n"
            f"Only in create(): {full_cols - batch_cols}\n"
            f"Only in create_batches(): {batch_cols - full_cols}"
        )
    def test_training_dataset_feature_columns(self, sample_ensemble):
        """Test that create_cat_training_dataset creates proper feature columns."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Get the columns used for model inputs
        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble.covcol not in model_input_cols
        assert sample_ensemble.taskcol("binary") not in model_input_cols

        # No darts_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"

    def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
        """Test feature columns for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")

        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble_mllabels.covcol not in model_input_cols
        assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols

        # No dartsml_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"

    def test_inference_vs_training_feature_consistency(self, sample_ensemble):
        """Test that inference batches have the same features as training data after column dropping.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training dataset (as in training.py)
        # Use only a small subset by creating just one batch
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Step 2: Create inference batch (as in inference.py)
        # Get just the first batch to speed up test
        batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
        batch = next(batch_generator, None)

        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        # Simulate the column dropping in predict_proba
        cols_to_drop = ["geometry"]
        if sample_ensemble.target == "darts_mllabels":
            cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
        else:
            cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]

        inference_batch = batch.drop(columns=cols_to_drop)
        inference_feature_cols = set(inference_batch.columns)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
            f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
            f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
        )

    def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
        """Test feature consistency for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Get just the first batch to speed up test
        batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
        batch = next(batch_generator, None)

        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        # Simulate the column dropping in predict_proba
        cols_to_drop = ["geometry"]
        if sample_ensemble_mllabels.target == "darts_mllabels":
            cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
        else:
            cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]

        inference_batch = batch.drop(columns=cols_to_drop)
        inference_feature_cols = set(inference_batch.columns)

        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}"
        )
    def test_all_tasks_feature_consistency(self, sample_ensemble):
        """Test that all task types produce consistent features."""
        tasks = ["binary", "count", "density"]
        feature_sets = {}

        for task in tasks:
            training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
            feature_sets[task] = set(training_data.X.data.columns)

        # All tasks should have the same features
        binary_features = feature_sets["binary"]
        for task, features in feature_sets.items():
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_dataset_shapes(self, sample_ensemble):
        """Test that training dataset has correct shapes."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        n_features = len(training_data.X.data.columns)
        n_samples_train = training_data.X.train.shape[0]
        n_samples_test = training_data.X.test.shape[0]

        # Check X shapes
        assert training_data.X.train.shape == (n_samples_train, n_features)
        assert training_data.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_data.y.train.shape == (n_samples_train,)
        assert training_data.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_data.dataset) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble):
        """Test that training features don't contain NaN values."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Convert to numpy for checking (handles both numpy and cupy arrays)
        X_train = np.asarray(training_data.X.train)
        X_test = np.asarray(training_data.X.test)

        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble):
        """Test that batches cover all data without duplication."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Collect all cell_ids from batches
        batch_cell_ids = set()
        for batch in batches:
            batch_ids = set(batch.index)
            # Check for duplicates across batches
            overlap = batch_cell_ids.intersection(batch_ids)
            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
            batch_cell_ids.update(batch_ids)

        # Check that all cell_ids from full dataset are in batches
        full_cell_ids = set(full_dataset.index)
        assert batch_cell_ids == full_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - full_cell_ids}"
        )


class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble):
        """Test that invalid task raises ValueError."""
        with pytest.raises(ValueError, match="Invalid task"):
            sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu")  # type: ignore

    def test_stats_method(self, sample_ensemble):
        """Test that get_stats returns expected structure."""
        stats = sample_ensemble.get_stats()

        assert "target" in stats
        assert "num_target_samples" in stats
        assert "members" in stats
        assert "total_features" in stats

        # Check that members dict contains info for each member
        for member in sample_ensemble.members:
            assert member in stats["members"]
            assert "variables" in stats["members"][member]
            assert "num_features" in stats["members"][member]
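To run only the feature-consistency tests from this file, pytest's Python entry point can be used (equivalent to the command-line invocation):

import pytest

pytest.main(["tests/test_dataset.py", "-k", "feature_consistency", "-v"])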