Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Run Tests

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

# Set up micromamba
- name: Set up micromamba
uses: mamba-org/setup-micromamba@v2
with:
environment-file: rf_diffusion/environment/ci_environment.yml
init-shell: bash
cache-environment: true

- name: Install pytest
shell: micromamba-shell {0}
run: |
python -m pip install pytest

- name: Download weights
run: |
mkdir weights
curl -o weights/train_session2024-07-08_1720455712_BFF_3.00.pt https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt

- name: Run tests
shell: micromamba-shell {0}
run: |
python -m pytest test/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.utils.nvtx import nvtx_range

from se3_transformer.runtime.utils import degree_to_dim

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.utils.nvtx import nvtx_range


class AttentionSE3(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.utils.nvtx import nvtx_range

from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.utils.nvtx import nvtx_range

from se3_transformer.model.fiber import Fiber

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# se3_transformer/utils/nvtx.py
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator

@contextmanager
def nvtx_range(message: str) -> Iterator[None]:
"""
Safe NVTX range context manager.

- If running with CUDA + NVTX support, emits real NVTX ranges.
- Otherwise, becomes a no-op (CPU-only CI, ROCm-only builds, etc).
"""
try:
import torch

if torch.cuda.is_available() and hasattr(torch.cuda, "nvtx"):
try:
from torch.cuda.nvtx import range as _nvtx_range
with _nvtx_range(message):
yield
return
except Exception:
# CUDA available but NVTX missing/misconfigured -> fall back to no-op
pass

yield
except Exception:
# torch not importable or other unexpected env issue -> no-op
yield

242 changes: 242 additions & 0 deletions rf_diffusion/environment/ci_environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
name: RFDpoly_env_ci_test
channels:
- pytorch
- pyg
- dglteam
- conda-forge
- bioconda
- defaults
dependencies:
- python=3.10.8
- pip=23.0.1
- pytorch=1.13.1
- cpuonly
- pyg=2.2.0
- pytorch-scatter=2.1.0
- pytorch-sparse=0.6.16
- pytorch-cluster=1.6.0
- dgl=1.0.1
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- anyio=3.5.0
- appdirs=1.4.4
- argon2-cffi=21.3.0
- argon2-cffi-bindings=21.2.0
- asttokens=2.0.5
- attrs=22.1.0
- babel=2.11.0
- backcall=0.2.0
- beautifulsoup4=4.11.1
- blas=1.0
- bleach=4.1.0
- bottleneck=1.3.5
- brotli=1.0.9
- brotli-bin=1.0.9
- brotlipy=0.7.0
- bzip2=1.0.8
- ca-certificates=2022.12.7
- cairo=1.16.0
- certifi=2022.12.7
- cffi=1.15.1
- charset-normalizer=2.0.4
- comm=0.1.2
- conda=23.1.0
- conda-content-trust=0.1.3
- conda-package-handling=2.0.2
- conda-package-streaming=0.7.0
- contourpy=1.0.5
- cryptography=39.0.1
- cycler=0.11.0
- dbus=1.13.18
- debugpy=1.5.1
- decorator=5.1.1
- defusedxml=0.7.1
- entrypoints=0.4
- executing=0.8.3
- expat=2.4.9
- flit-core=3.6.0
- fontconfig=2.14.1
- fonttools=4.25.0
- freetype=2.12.1
- giflib=5.2.1
- glib=2.69.1
- gst-plugins-base=1.14.1
- gstreamer=1.14.1
- icu=58.2
- idna=3.4
- intel-openmp=2021.4.0
- ipykernel=6.19.2
- ipython=8.10.0
- ipython_genutils=0.2.0
- jedi=0.18.1
- jinja2=3.1.2
- joblib=1.1.1
- jpeg=9e
- json5=0.9.6
- jsonschema=4.17.3
- jupyter_client=7.4.9
- jupyter_core=5.2.0
- jupyter_server=1.23.4
- jupyterlab=3.5.3
- jupyterlab_pygments=0.1.2
- jupyterlab_server=2.19.0
- kiwisolver=1.4.4
- krb5=1.19.4
- lcms2=2.12
- ld_impl_linux-64=2.38
- lerc=3.0
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libclang=10.0.1
- libdeflate=1.17
- libedit=3.1.20221030
- libevent=2.1.12
- libffi=3.4.2
- libgcc-ng=12.2.0
- libgfortran-ng=11.2.0
- libgfortran5=11.2.0
- libiconv=1.17
- libllvm10=10.0.1
- libpng=1.6.39
- libpq=12.9
- libsodium=1.0.18
- libstdcxx-ng=11.2.0
- libtiff=4.5.0
- libuuid=1.41.5
- libwebp=1.2.4
- libwebp-base=1.2.4
- libxcb=1.15
- libxkbcommon=1.0.1
- libxml2=2.9.14
- libxslt=1.1.35
- libzlib=1.2.13
- llvm-openmp=15.0.7
- lxml=4.9.1
- lz4-c=1.9.4
- markupsafe=2.1.1
- matplotlib=3.7.0
- matplotlib-base=3.7.0
- matplotlib-inline=0.1.6
- mistune=0.8.4
- mkl=2021.4.0
- mkl-service=2.4.0
- mkl_fft=1.3.1
- mkl_random=1.2.2
- munkres=1.1.4
- nbclassic=0.5.2
- nbclient=0.5.13
- nbconvert=6.5.4
- nbformat=5.7.0
- ncurses=6.4
- nest-asyncio=1.5.6
- networkx=2.8.4
- notebook=6.5.2
- notebook-shim=0.2.2
- nspr=4.33
- nss=3.74
- numexpr=2.8.4
- numpy=1.23.5
- numpy-base=1.23.5
- openbabel=3.1.1
- openssl=1.1.1t
- packaging=22.0
- pandas=1.5.3
- pandocfilters=1.5.0
- parso=0.8.3
- pcre=8.45
- pexpect=4.8.0
- pickleshare=0.7.5
- pillow=9.4.0
- pip=23.0.1
- pixman=0.40.0
- platformdirs=2.5.2
- pluggy=1.0.0
- ply=3.11
- pooch=1.4.0
- prometheus_client=0.14.1
- prompt-toolkit=3.0.36
- psutil=5.9.0
- ptyprocess=0.7.0
- pure_eval=0.2.2
- pycosat=0.6.4
- pycparser=2.21
- pyg=2.2.0
- pygments=2.11.2
- pyopenssl=23.0.0
- pyparsing=3.0.9
- pyqt=5.15.7
- pyrsistent=0.18.0
- pysocks=1.7.1
- python-dateutil=2.8.2
- python-fastjsonschema=2.16.2
- python_abi=3.10
- pytz=2022.7
- pyzmq=23.2.0
- qt-main=5.15.2
- qt-webengine=5.15.9
- qtwebkit=5.212
- readline=8.2
- requests=2.28.1
- ruamel.yaml=0.17.21
- ruamel.yaml.clib=0.2.6
- scikit-learn=1.2.1
- scipy=1.10.0
- seaborn=0.12.2
- send2trash=1.8.0
- setuptools=65.5.0
- sip=6.6.2
- six=1.16.0
- sniffio=1.2.0
- soupsieve=2.3.2.post1
- sqlite=3.40.1
- stack_data=0.2.0
- terminado=0.17.1
- threadpoolctl=2.2.0
- tinycss2=1.2.1
- tk=8.6.12
- toml=0.10.2
- tomli=2.0.1
- toolz=0.12.0
- tornado=6.2
- tqdm=4.64.1
- traitlets=5.7.1
- typing-extensions=4.4.0
- typing_extensions=4.4.0
- tzdata=2022g
- urllib3=1.26.14
- wcwidth=0.2.5
- webencodings=0.5.1
- websocket-client=0.58.0
- wheel=0.37.1
- xz=5.2.10
- zeromq=4.3.4
- zlib=1.2.13
- zstandard=0.19.0
- zstd=1.5.2
- pip:
- antlr4-python3-runtime==4.9.3
- assertpy==1.1
- click==8.1.3
- colorama==0.4.6
- deepdiff==6.2.3
- docker-pycreds==0.4.0
- e3nn==0.5.1
- gitdb==4.0.10
- gitpython==3.1.31
- hydra-core==1.3.2
- mpmath==1.3.0
- omegaconf==2.3.0
- opt-einsum==3.3.0
- opt-einsum-fx==0.1.4
- ordered-set==4.1.0
- orjson==3.8.7
- pathtools==0.1.2
- protobuf==4.22.1
- pyqt5-sip==12.11.0
- pyyaml==6.0
- sentry-sdk==1.16.0
- setproctitle==1.3.2
- smmap==5.0.0
- sympy==1.11.1
- wandb==0.13.11
4 changes: 2 additions & 2 deletions rf_diffusion/environment/macos_environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: RFDpoly_env
name: RFDpoly_env_macos
channels:
- pytorch
- conda-forge
Expand All @@ -18,4 +18,4 @@ dependencies:
- pip:
- dgl==1.0.1
- e3nn==0.5.1
- hydra-core==1.3.2
- hydra-core==1.3.2
Loading
Loading