After installing micromamba, install the following packages:

Climlab environment
===================
micromamba install python wrf-python matplotlib=3.8.0 salem pandas numpy cartopy -c conda-forge
pip install ipython==8.14.0 ipykernel metpy climlab
pip install git+https://github.com/cvanelteren/proplot.git@v0.9.91

MMD3201 environment
===================
micromamba install python matplotlib=3.8.0 pandas numpy cartopy scipy -c conda-forge
pip install ipython==8.14.0 ipykernel
pip install git+https://github.com/cvanelteren/proplot.git@v0.9.91

NeuralGCM environment
=====================
micromamba install python matplotlib=3.8.0 pandas numpy cartopy scipy -c conda-forge
pip install neuralgcm ipython==8.14.0 ipykernel gcsfs
pip install git+https://github.com/cvanelteren/proplot.git@v0.9.91
git clone git@github.com:jax-ml/jax.git

From jax directory (build takes a while):
python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.6.0 --clang_path=/usr/bin/clang --build_cuda_with_clang

Then from jax/dist directory:
pip install jax_cuda12_pjrt-0.5.3.dev20250319-py3-none-manylinux2014_x86_64.whl --force-reinstall

Finally:
pip install -U "jax[cuda12]"
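
Once an environment is set up, a quick sanity check can confirm that the pinned versions were actually installed, since a later pip install may replace them. The following is a minimal sketch, not part of the original instructions. The helper names check_pins and jax_backend are hypothetical, and the pins listed are only the two that the commands above state explicitly (matplotlib 3.8.0 and ipython 8.14.0).

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the install commands above (assumed to apply to all three environments).
PINS = {"matplotlib": "3.8.0", "ipython": "8.14.0"}


def check_pins(pins, get_version=version):
    """Return {package: (expected, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed.
    `get_version` defaults to importlib.metadata.version and can be swapped for testing.
    """
    problems = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except PackageNotFoundError:
            found = None
        if found != expected:
            problems[name] = (expected, found)
    return problems


def jax_backend():
    """Return the name of JAX's default backend, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()
```

Running print(check_pins(PINS)) inside the activated environment should print an empty dict when both pins match. In the NeuralGCM environment, print(jax_backend()) returning "cpu" would suggest that the CUDA plugin was not picked up.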