█████╗ ███████╗██╗ ██████╗ █████╗ ██████╗ ██████╗ ███████╗██████╗
██╔══██╗██╔════╝██║██╔════╝ ██╔══██╗██╔══██╗██╔══██╗██╔════╝██╔══██╗
███████║███████╗██║██║ ██████ ███████║██║ ██║██║ ██║█████╗ ██████╔╝
██╔══██║╚════██║██║██║ ██╔══██║██║ ██║██║ ██║██╔══╝ ██╔══██╗
██║ ██║███████║██║╚██████╗ ██║ ██║██████╔╝██████╔╝███████╗██║ ██║
╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═════╝ ╚═════╝ ╚══════╝╚═╝ ╚═╝
A JAX backend for a real FPGA that performs a 4-lane, 128-bit vector add over i32 tensors stored in DDR3.
The goal is to provide an open-source reference stack for rapidly prototyping custom AI hardware, from the ML framework down to physical execution
(JAX → StableHLO/MLIR → PJRT → Driver → AMD Artix-7 FPGA, no emulation layers).
- Runs on a real FPGA (AMD Artix-7 XC7A50T), so you can observe actual behavior instead of relying on emulators
- End-to-end insights: memory latency, network transport, and compiler/runtime overhead
- Communicates over Ethernet instead of RDMA and uses DDR3 instead of HBM, which makes it more accessible to most developers
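The 4-lane, 128-bit add over i32 tensors can be modeled on the host as a reference for checking device results. The sketch below is not taken from the FPGA design: the function names are hypothetical, and both the little-endian lane packing and the two's-complement wraparound on overflow are assumptions that should be confirmed against processor/README.md.

```python
import struct

LANES = 4        # four i32 lanes per beat
LANE_BYTES = 4   # 4 lanes * 4 bytes = 16 bytes = 128 bits


def wrap_i32(value):
    """Wrap an integer into the signed 32-bit range (assumed overflow behavior)."""
    value &= 0xFFFFFFFF
    return value - (1 << 32) if value & 0x80000000 else value


def vector_add_i32(a, b):
    """Add two equal-length i32 sequences, one 4-lane beat at a time."""
    if len(a) != len(b):
        raise ValueError("operands must have the same number of elements")
    out = []
    for base in range(0, len(a), LANES):
        # Each iteration models one 128-bit beat; the last beat may be partial.
        beat_a = a[base:base + LANES]
        beat_b = b[base:base + LANES]
        out.extend(wrap_i32(x + y) for x, y in zip(beat_a, beat_b))
    return out


def pack_beat(lanes):
    """Pack exactly four i32 values into 16 little-endian bytes (assumed layout)."""
    return struct.pack("<4i", *lanes)
```

A model like this is mainly useful as an oracle in tests: run the same inputs through the device and compare element by element.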
- pjrt/README.md: the PJRT plugin exposes asic-adder to JAX through the OpenXLA PJRT C API and wires device_put, device_get, and synchronous @jax.jit execution into the compiler and driver.
- compiler/README.md: the compiler accepts StableHLO as MLIR text or bytecode, validates the restricted add-only program shape, lowers SSA into executable commands, then performs a second runtime planning pass to assign tensor IDs and DDR addresses.
- driver/README.md: the driver is the host runtime that validates compiler output, opens a fresh raw-socket transport, and emits custom Ethernet frames for H2D, ADD, and D2H operations.
- processor/README.md: the FPGA design receives those frames, stores payloads in DDR3 through MIG, executes the add kernel over DDR-resident tensors, and packetizes results back to the host.
Low-level implementation details are in the READMEs of each component.
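To give a feel for what a planning pass that assigns tensor IDs and DDR addresses might do, here is a minimal sketch. It is not the compiler's actual algorithm: the TensorSlot type, the plan function, the sequential ID scheme, and the 16-byte alignment (chosen to match one 128-bit beat) are all assumptions made for illustration.

```python
from dataclasses import dataclass

BEAT_BYTES = 16  # assumed alignment: one 128-bit beat
I32_BYTES = 4


@dataclass(frozen=True)
class TensorSlot:
    """Hypothetical record of where one tensor lives in DDR."""
    tensor_id: int
    ddr_addr: int
    num_bytes: int


def align_up(value, alignment):
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def plan(element_counts, base_addr=0):
    """Assign sequential IDs and non-overlapping, beat-aligned DDR addresses."""
    slots = []
    addr = base_addr
    for tensor_id, count in enumerate(element_counts):
        size = count * I32_BYTES
        slots.append(TensorSlot(tensor_id, addr, size))
        # The next tensor starts at the following beat boundary.
        addr = align_up(addr + size, BEAT_BYTES)
    return slots
```

For example, plan([5, 4, 1]) places a 20-byte tensor at address 0, a 16-byte tensor at address 32, and a 4-byte tensor at address 48, so no two tensors share a beat.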
- CMake 3.16+
- a C++17 compiler
- Python 3.12 environment with jax[cpu]==0.9.0.1 installed
Install packages:
- libmlir-18-dev
- mlir-18-tools
Install MLIR/LLVM 18 development packages:
wget https://apt.llvm.org/llvm.sh
chmod +x llvm.sh
sudo ./llvm.sh 18
Build the compiler first, then the driver, then PJRT.
cmake -S compiler -B compiler/build
cmake --build compiler/build -j
cmake -S driver -B driver/build
cmake --build driver/build -j
cmake -S pjrt -B pjrt/build
cmake --build pjrt/build -j
Python scripts in examples/ use JAX to verify Host<-->Device transfers and @jax.jit execution of add-only kernels on the FPGA.
The raw Ethernet runtime requires the Python interpreter to have the capability to open raw sockets (cap_net_raw). You can grant it with the setcap command:
PY=$(readlink -f "$(command -v python3.12)") && \
sudo setcap cap_net_raw+ep "$PY" && \
getcap "$PY" && \
python host/test_pjrt.py
# Then run the example
python examples/full.py
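For orientation, the kind of check the example scripts perform (a host-to-device transfer, a jitted add, and a device-to-host transfer) might look roughly like the sketch below. This is not the content of examples/full.py. It uses only standard JAX calls and runs on whatever device JAX selects by default; how the asic-adder plugin is registered and selected is not shown here and is an open assumption.

```python
import jax
import numpy as np


@jax.jit
def add(a, b):
    # The only operation the restricted add-only program shape allows.
    return a + b


# Host-side i32 inputs.
host_a = np.arange(8, dtype=np.int32)
host_b = np.full(8, 10, dtype=np.int32)

# Host -> device transfer. On a machine with the plugin registered this
# would target the FPGA; here it is simply the default device.
device = jax.devices()[0]
dev_a = jax.device_put(host_a, device)
dev_b = jax.device_put(host_b, device)

# Jitted execution, then device -> host transfer.
out = jax.device_get(add(dev_a, dev_b))

# Compare against the result computed on the host.
np.testing.assert_array_equal(out, host_a + host_b)
```

Comparing against a NumPy result computed on the host keeps the check independent of the device under test.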