diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 291be9f28..be4291687 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,98 +9,95 @@ on: branches: [master] env: - PYTEST_ADDOPTS: "--cov=numpyro --cov-append" + PYTEST_ADDOPTS: "--cov=numpyro --cov-append --cov-report=lcov" jobs: prek: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - name: prek check uses: j178/prek-action@v1 with: extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy - lint: - runs-on: ubuntu-latest strategy: matrix: python-version: ["3.11", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y pandoc gsfonts - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install '.[doc,test]' - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -r docs/requirements.txt - pip freeze + uv sync \ + --upgrade \ + --extra cpu \ + --group ci \ + --group docs \ + --group test + uv pip freeze - name: Lint with mypy and ruff run: | - make lint + uv run make lint - name: Build documentation run: | - make docs + uv run make docs - name: Test documentation run: | - make doctest - python -m doctest -v README.md - + uv run make doctest + uv run python -m doctest -v README.md test-modeling: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y graphviz - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv sync \ + --upgrade \ + --extra cpu \ + --group ci \ + --group dev \ + --group test + uv pip freeze - name: Test with pytest run: | - CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ + CI=1 uv run pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum" + JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum" - name: Test tracer leak if: matrix.python-version == '3.13' env: JAX_CHECK_TRACER_LEAKS: 1 run: | - pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit - pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke - pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run - pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths - pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - + uv run pytest -vs \ + test/infer/test_mcmc.py::test_chain_inside_jit \ + test/infer/test_mcmc.py::test_chain_jit_args_smoke \ + test/infer/test_mcmc.py::test_model_with_multiple_exec_paths \ + test/infer/test_mcmc.py::test_reuse_mcmc_run + uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -108,52 +105,55 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel: true flag-name: test-modeling - + file: coverage.lcov test-inference: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv sync \ + --upgrade \ + --extra cpu \ + --group ci \ + --group dev \ + --group test + uv pip freeze - name: Test with pytest run: | - pytest -vs --durations=20 test/infer/test_mcmc.py - pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py - pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py + uv run pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py + uv run pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py + uv run pytest -vs --durations=20 test/infer/test_mcmc.py - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 + JAX_ENABLE_X64=1 uv run pytest -vs test/infer/test_mcmc.py -k x64 - name: Test chains run: | - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/stochastic_support/test_dcc.py + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/test_tfp.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" - name: Test custom prng run: | - JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py + JAX_ENABLE_CUSTOM_PRNG=1 uv run pytest -vs test/infer/test_mcmc.py - name: Test nested sampling run: | - JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py + JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -161,33 +161,38 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel: true flag-name: test-inference - + file: coverage.lcov examples: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,examples,test]' - pip freeze + uv sync \ + --upgrade \ + --extra cpu \ + --group ci \ + --group dev \ + --group examples \ + --group test + uv pip freeze - name: Test with pytest run: | - CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example + CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -195,12 +200,12 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel: true flag-name: examples - + file: coverage.lcov finish: - needs: [test-modeling, test-inference, examples] runs-on: ubuntu-latest + if: github.repository == 'pyro-ppl/numpyro' steps: - name: Coveralls finished uses: coverallsapp/github-action@v2 @@ -208,4 +213,3 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true carryforward: "test-modeling,test-inference,examples" - diff --git a/.gitignore b/.gitignore index 62b04d0ba..e7d17112d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +uv.lock numpyro.egg-info __pycache__/ .ipynb_checkpoints/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 97f2b8fc5..4aa565561 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -4,13 +4,18 @@ build: os: ubuntu-24.04 tools: python: "3.13" + jobs: # https://docs.readthedocs.com/platform/stable/build-customization.html#install-dependencies-with-uv + pre_create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + create_environment: + - uv venv "${READTHEDOCS_VIRTUALENV_PATH}" + install: + - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --extra cpu --group dev --group docs --group test sphinx: configuration: docs/source/conf.py formats: - pdf - -python: - install: - - requirements: docs/requirements.txt diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 7e877cb1d..000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -flax -funsor -ipython -jax -jaxlib -jaxns==2.6.9 -Jinja2 -matplotlib -multipledispatch -nbsphinx>=0.8.9 -numpy -optax -pillow -pylab-sdk -pyyaml -readthedocs-sphinx-search>=0.3.2 -sphinx>=5 -sphinx-gallery -sphinx_rtd_theme -tfp-nightly -tqdm diff --git a/pyproject.toml b/pyproject.toml index c126c219f..1007cf18a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,23 @@ cpu = ["jax[cpu]>=0.7.0"] cuda12 = ["jax[cuda12]>=0.7.0"] cuda13 = ["jax[cuda13]>=0.7.0"] tpu = ["jax[tpu]>=0.7.0"] + +[project.urls] +Changelog = "https://github.com/pyro-ppl/numpyro/blob/main/CHANGELOG.md" +Discussion = "https://github.com/pyro-ppl/numpyro/discussions" +Homepage = "https://github.com/pyro-ppl/numpyro" +Issues = "https://github.com/pyro-ppl/numpyro/issues" + +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[dependency-groups] +ci = [ + "coverage>=7.13.5", + "coveralls>=4.1.0", + "funsor", +] dev = [ "equinox", "flax", @@ -50,21 +67,11 @@ dev = [ "requests", # pylab dependency "tfp-nightly", ] -test = [ - "importlib-metadata<5.0", - "mypy>=1.13", - "pyro-api>=0.1.1", - "pytest>=4.1", - "ruff>=0.1.8", - "scikit-learn", - "scipy>=1.9", - "ty>=0.0.4", -] -doc = [ +docs = [ "ipython", # sphinx needs this to render codes "nbsphinx>=0.8.9", "readthedocs-sphinx-search>=0.3.2", - "sphinx_rtd_theme", + "sphinx-rtd-theme", "sphinx-gallery", "sphinx>=5", ] @@ -77,16 +84,16 @@ examples = [ "seaborn", "wordcloud", ] - -[project.urls] -Changelog = "https://github.com/pyro-ppl/numpyro/blob/main/CHANGELOG.md" -Discussion = "https://github.com/pyro-ppl/numpyro/discussions" -Homepage = "https://github.com/pyro-ppl/numpyro" -Issues = "https://github.com/pyro-ppl/numpyro/issues" - -[build-system] -requires = ["setuptools>=61", "wheel"] -build-backend = "setuptools.build_meta" +test = [ + "importlib-metadata<5.0", + "mypy>=1.13", + "pyro-api>=0.1.1", + "pytest>=4.1", + "ruff>=0.1.8", + "scikit-learn", + "scipy>=1.9", + "ty>=0.0.4", +] # NOTE: this can be simplified using src-layout [tool.setuptools.packages.find] @@ -234,3 +241,6 @@ module = [ "numpyro.distributions.transforms", ] ignore_errors = false + +[tool.uv.sources] +funsor = { git = "https://github.com/pyro-ppl/funsor.git" }