Refactor xtask to use tracel-xtask and refactor CI workflow (#2063)

* Migrate to xtask-common crate

* Fix example crate name for simple-regression

* Refactor CI workflows

* Flatten linux workflows

* Install grcov and typos from binaries

Although xtask-common support auto-installation of these tools via cargo
it is a lot faster to install them via the distributed binaries

* [CI] Update Rust caches on failure

* [CI] Add shell bash to jobs steps

* [CI] Try cache all crates

* Fix no-std tests not executing

* [CI] Add CARGO_INCREMENTAL 0

* Exclude tch and cuda from tests and merge crates and examples steps

* Fix some typos found with typos cli

* Add Windows and MacOS jobs

* Only test no-std with default rust target

* Fix syntax in composite action setup-windows

* Enable incremental build

* Upate cargo alias for xtask

* Bump to github action checkout v4

* Revert to tch 0.15 and disable WGPU on windows

* Fix color in output

* Add Test command

* Test long output errorring

* Build and test workspace before additional builds and tests

* Disable wgpu tests on windows

* Remove tests- prefix in CI workflow jobs name

* Add Checks command

* Rename ci workflow jobs

* Execute windows and macos CI tests on rust stable only

* Rename integration test files with a test_ prefix

* Fix format

* Don't auto-correct "arange" with typos

* Fix typos in code

* Merge unit and integration tests steps

* Fix macos tests

* Fix coverage step

* Name publish-crate workflow

* Fix bad cache name for macos

* Reorganize commands and get rid of the ci command

* Fix dispatch to customized commands for Burn

* Update to last version of tracel-xtask

* Remove unnecessary shell bash in ci workflow

* Update cargo.lock

* Fix format

* Bump tracel-xtask

* Simplify dispatch of base commands using updated macro

* Update to last version of tracel-xtask

* Adapt legacy run_checks script with new xtask commands

* Run xtask in debug for faster compilation time

* Ditch build step in ci and enable coverage for stable linux only

* Freeze tracel-xtask to specific commit rev

* Update cargo.lock

* Update Step 6 of CONTRIBUTING guidelines about run-checks script

* Remove unneeded CI and CD paragraphgs in CONRIBUTING.md

* Change cache version

* Fix typos

* Use centralized actions and workflows

* Update to last version of tracel-xtask

* Update CONTRIBUTING file to mention integration tests

* Add custom build for thumbv6m-none-eabi

* Ignore onnx files for typos check

* Fix action and workflow paths in github workflows

* Fix custom builds on MacOS

* Bump tracel-xtask crate to last version

* Update Cargo.lock

* Update publish workflow to use reusable workflow in tracel repo

* Add --ci flag for build and test commands
This commit is contained in:
Sylvain Benner 2024-08-28 15:57:13 -04:00 committed by GitHub
parent 40d321cc0d
commit a88c69af4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 1048 additions and 2106 deletions

View File

@ -1,2 +1,2 @@
[alias]
xtask = "run --manifest-path ./xtask/Cargo.toml --"
xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"

View File

@ -1,14 +0,0 @@
name: "Install llvmpipe and lavapipe"
description: "Installs software only Vulkan driver"
runs:
using: "composite"
steps:
- name: Install llvmpipe and lavapipe
shell: bash
run: |
sudo apt-get update -y -qq
for i in {1..5}; do
sudo add-apt-repository ppa:kisak/kisak-mesa -y && break || sleep 5;
done
sudo apt-get update
sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers

247
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,247 @@
name: CI
on:
push:
branches:
- main
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
pull_request:
types: [opened, synchronize]
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
env:
# Note: It is not possible to define env vars in composite actions.
# To work around this issue we use inputs and define all the env vars here.
# Cargo
CARGO_TERM_COLOR: "always"
# Dependency versioning
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# Sourced from https://vulkan.lunarg.com/sdk/home#linux
VULKAN_SDK_VERSION: "1.3.268"
# Sourced from https://archive.mesa3d.org/. Bumping this requires
# updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
MESA_VERSION: "23.3.1"
# Corresponds to https://github.com/gfx-rs/ci-build/releases
MESA_CI_BINARY_BUILD: "build18"
# Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP
WARP_VERSION: "1.0.8"
# Sourced from https://github.com/microsoft/DirectXShaderCompiler/releases
# Must also be changed in shaders.yaml
DXC_RELEASE: "v1.7.2308"
DXC_FILENAME: "dxc_2023_08_14.zip"
# Mozilla Grcov
GRCOV_LINK: "https://github.com/mozilla/grcov/releases/download"
GRCOV_VERSION: "0.8.19"
# Typos version
TYPOS_LINK: "https://github.com/crate-ci/typos/releases/download"
TYPOS_VERSION: "1.23.4"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
code-quality:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Audit
run: cargo xtask check audit
# --------------------------------------------------------------------------------
- name: Format
shell: bash
env:
# work around for colors
# see: https://github.com/rust-lang/rustfmt/issues/3385
TERM: xterm-256color
run: cargo xtask check format
# --------------------------------------------------------------------------------
- name: Lint
run: cargo xtask check lint
# --------------------------------------------------------------------------------
- name: Typos
uses: tracel-ai/github-actions/check-typos@v1
documentation:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Documentation Build
run: cargo xtask doc build
# --------------------------------------------------------------------------------
- name: Documentation Tests
run: cargo xtask doc tests
linux-std-tests:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable, 1.79.0]
include:
- rust: stable
cache-version: stable
coverage: --enable-coverage
- rust: 1.79.0
cache-version: 1-79-0
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Setup Linux runner
uses: tracel-ai/github-actions/setup-linux@v1
with:
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
mesa-version: ${{ env.MESA_VERSION }}
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
# --------------------------------------------------------------------------------
- name: Install grcov
if: matrix.rust == 'stable'
shell: bash
run: |
curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" |
tar xj -C $HOME/.cargo/bin
cargo xtask coverage install
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask ${{ matrix.coverage }} test --ci
# --------------------------------------------------------------------------------
- name: Generate lcov.info
if: matrix.rust == 'stable'
# /* is to exclude std library code coverage from analysis
run: cargo xtask coverage generate --ignore "/*,xtask/*,examples/*"
# --------------------------------------------------------------------------------
- name: Codecov upload lcov.info
if: matrix.rust == 'stable'
uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
linux-no-std-tests:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable, 1.79.0]
include:
- rust: stable
cache-version: stable
- rust: 1.79.0
cache-version: 1-79-0
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux-no-std
# --------------------------------------------------------------------------------
- name: Setup Linux runner
uses: tracel-ai/github-actions/setup-linux@v1
with:
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
mesa-version: ${{ env.MESA_VERSION }}
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
# --------------------------------------------------------------------------------
- name: Crates Build
run: cargo xtask --execution-environment no-std build --ci
# --------------------------------------------------------------------------------
- name: Crates Tests
run: cargo xtask --execution-environment no-std test --ci
windows-std-tests:
runs-on: windows-2022
env:
DISABLE_WGPU: '1'
# Keep the stragegy to be able to easily add new rust versions if required
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-windows
# --------------------------------------------------------------------------------
- name: Setup Windows runner
if: env.DISABLE_WGPU != '1'
uses: tracel-ai/github-actions/setup-windows@v1
with:
dxc-release: ${{ env.DXC_RELEASE }}
dxc-filename: ${{ env.DXC_FILENAME }}
mesa-version: ${{ env.MESA_VERSION }}
warp-version: ${{ env.WARP_VERSION }}
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask test --ci
macos-std-tests:
runs-on: blaze/macos-14
# Keep the stragegy to be able to easily add new rust versions if required
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-macos
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask test --ci

View File

@ -1,24 +0,0 @@
on:
workflow_call:
inputs:
crate:
required: true
type: string
secrets:
CRATES_IO_API_TOKEN:
required: true
jobs:
publish-crate:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
- name: install rust
uses: dtolnay/rust-toolchain@stable
- name: publish to crates.io
run: cargo xtask publish ${{ inputs.crate }}
env:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

View File

@ -7,51 +7,57 @@ on:
jobs:
publish-burn-derive:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: burn-derive
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-dataset:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: burn-dataset
needs:
- publish-burn-common
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-common:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: burn-common
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tensor-testgen:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: burn-tensor-testgen
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tensor:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor-testgen
- publish-burn-common
with:
crate: burn-tensor
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-fusion:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-common
with:
crate: burn-fusion
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-jit:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-common
- publish-burn-fusion
@ -59,39 +65,43 @@ jobs:
- publish-burn-ndarray
with:
crate: burn-jit
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-autodiff:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-tensor-testgen
- publish-burn-common
with:
crate: burn-autodiff
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tch:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
with:
crate: burn-tch
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-ndarray:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
- publish-burn-common
with:
crate: burn-ndarray
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-wgpu:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
@ -100,10 +110,11 @@ jobs:
- publish-burn-jit
with:
crate: burn-wgpu
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-cuda:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
@ -112,20 +123,22 @@ jobs:
- publish-burn-jit
with:
crate: burn-cuda
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-candle:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
- publish-burn-tch
with:
crate: burn-candle
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-core:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-dataset
- publish-burn-common
@ -138,35 +151,40 @@ jobs:
- publish-burn-candle
with:
crate: burn-core
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-train:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-core
with:
crate: burn-train
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-core
- publish-burn-train
with:
crate: burn
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-import:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn
with:
crate: burn-import
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-onnx-ir:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: onnx-ir
secrets: inherit
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

View File

@ -1,283 +0,0 @@
name: test
on:
push:
branches:
- main
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
pull_request:
types: [opened, synchronize]
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
env:
#
# Dependency versioning
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
#
# Sourced from https://vulkan.lunarg.com/sdk/home#linux
VULKAN_SDK_VERSION: "1.3.268"
# Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP
WARP_VERSION: "1.0.8"
# Sourced from https://github.com/microsoft/DirectXShaderCompiler/releases
#
# Must also be changed in shaders.yaml
DXC_RELEASE: "v1.7.2308"
DXC_FILENAME: "dxc_2023_08_14.zip"
# Sourced from https://archive.mesa3d.org/. Bumping this requires
# updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
MESA_VERSION: "23.3.1"
# Corresponds to https://github.com/gfx-rs/ci-build/releases
CI_BINARY_BUILD: "build18"
# Typos version
TYPOS_VERSION: "1.16.20"
# Grcov version
GRCOV_VERSION: "0.8.18"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
tests:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [blaze/macos-14, ubuntu-22.04, windows-2022]
# We support both the latest Rust toolchain and the preceding version.
rust: [stable, 1.79.0]
test: ['std', 'no-std', 'examples']
include:
- cache: stable
rust: stable
- cache: 1-79-0
rust: 1.79.0
- os: ubuntu-22.04
coverage-flags: COVERAGE=1
rust: stable
test: std
- os: blaze/macos-14
rust: stable
test: std
- os: windows-2022
wgpu-flags: "DISABLE_WGPU=1"
# not used yet, as wgpu tests are disabled on windows for now
# see issue: https://github.com/tracel-ai/burn/issues/1062
# auto-graphics-backend-flags: "AUTO_GRAPHICS_BACKEND=dx12";'
exclude:
# only need to check this once
- rust: 1.79.0
test: 'examples'
# Do not run no-std tests on macos
- os: blaze/macos-14
test: 'no-std'
# Do not run no-std tests on Windows
- os: windows-2022
test: 'no-std'
steps:
- name: checkout
uses: actions/checkout@v4
- name: install rust
uses: dtolnay/rust-toolchain@master
with:
components: rustfmt, clippy
toolchain: ${{ matrix.rust }}
- name: caching
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.test}}-${{ hashFiles('**/Cargo.toml') }}
prefix-key: "v5-rust"
- name: free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo swapoff -a
sudo rm -f /swapfile
sudo apt clean
df -h
cargo clean --package burn-tch
- name: install llvmpipe and lavapipe
if: runner.os == 'Linux'
uses: ./.github/actions/setup-llvmpipe-lavapipe
- name: Run cargo clippy for stable version
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
uses: giraffate/clippy-action@v1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
# Run clippy for each workspace, targets, and featrues, considering
# warnings as errors
clippy_flags: --all-targets -- -Dwarnings
# Do not filter results
filter_mode: nofilter
# Report clippy annotations as snippets
reporter: github-pr-check
- name: Install grcov
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
env:
GRCOV_LINK: https://github.com/mozilla/grcov/releases/download
run: |
curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" |
tar xj -C $HOME/.cargo/bin
# -----------------------------------------------------------------------------------
# BEGIN -- Windows steps disabled as long as DISABLE_WGPU=1 (wgpu tests are disabled)
# -----------------------------------------------------------------------------------
# - name: (windows) install dxc
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# curl.exe -L --retry 5 https://github.com/microsoft/DirectXShaderCompiler/releases/download/$DXC_RELEASE/$DXC_FILENAME -o dxc.zip
# 7z.exe e dxc.zip -odxc bin/x64/{dxc.exe,dxcompiler.dll,dxil.dll}
# # We need to use cygpath to convert PWD to a windows path as we're using bash.
# cygpath --windows "$PWD/dxc" >> "$GITHUB_PATH"
# - name: (windows) install warp
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# # Make sure dxc is in path.
# dxc --version
# curl.exe -L --retry 5 https://www.nuget.org/api/v2/package/Microsoft.Direct3D.WARP/$WARP_VERSION -o warp.zip
# 7z.exe e warp.zip -owarp build/native/amd64/d3d10warp.dll
# mkdir -p target/llvm-cov-target/debug/deps
# cp -v warp/d3d10warp.dll target/llvm-cov-target/debug/
# cp -v warp/d3d10warp.dll target/llvm-cov-target/debug/deps
# - name: (windows) install mesa
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# curl.exe -L --retry 5 https://github.com/pal1000/mesa-dist-win/releases/download/$MESA_VERSION/mesa3d-$MESA_VERSION-release-msvc.7z -o mesa.7z
# 7z.exe e mesa.7z -omesa x64/{opengl32.dll,libgallium_wgl.dll,libglapi.dll,vulkan_lvp.dll,lvp_icd.x86_64.json}
# cp -v mesa/* target/llvm-cov-target/debug/
# cp -v mesa/* target/llvm-cov-target/debug/deps
# # We need to use cygpath to convert PWD to a windows path as we're using bash.
# echo "VK_DRIVER_FILES=`cygpath --windows $PWD/mesa/lvp_icd.x86_64.json`" >> "$GITHUB_ENV"
# echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV"
# -----------------------------------------------------------------------------------
# END -- Windows steps disabled as long as DISABLE_WGPU=1 (wgpu tests are disabled)
# -----------------------------------------------------------------------------------
- name: (linux) install vulkan sdk
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
if: runner.os == 'Linux'
shell: bash
run: |
set -e
sudo apt-get update -y -qq
# vulkan sdk
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-$VULKAN_SDK_VERSION-jammy.list https://packages.lunarg.com/vulkan/$VULKAN_SDK_VERSION/lunarg-vulkan-$VULKAN_SDK_VERSION-jammy.list
sudo apt-get update
sudo apt install -y vulkan-sdk
- name: (linux) install mesa
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
if: runner.os == 'Linux'
shell: bash
run: |
set -e
curl -L --retry 5 https://github.com/gfx-rs/ci-build/releases/download/$CI_BINARY_BUILD/mesa-$MESA_VERSION-linux-x86_64.tar.xz -o mesa.tar.xz
mkdir mesa
tar xpf mesa.tar.xz -C mesa
# The ICD provided by the mesa build is hardcoded to the build environment.
#
# We write out our own ICD file to point to the mesa vulkan
cat <<- EOF > icd.json
{
"ICD": {
"api_version": "1.1.255",
"library_path": "$PWD/mesa/lib/x86_64-linux-gnu/libvulkan_lvp.so"
},
"file_format_version": "1.0.0"
}
EOF
echo "VK_DRIVER_FILES=$PWD/icd.json" >> "$GITHUB_ENV"
echo "LD_LIBRARY_PATH=$PWD/mesa/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH" >> "$GITHUB_ENV"
echo "LIBGL_DRIVERS_PATH=$PWD/mesa/lib/x86_64-linux-gnu/dri" >> "$GITHUB_ENV"
- name: run checks & tests
shell: bash
run: ${{ matrix.coverage-flags }} ${{ matrix.wgpu-flags }} cargo xtask run-checks ${{ matrix.test }}
- name: Codecov upload
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
check-typos:
runs-on: ubuntu-22.04
steps:
- name: checkout
uses: actions/checkout@v4
- name: caching
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-typos-${{ hashFiles('**/Cargo.toml') }}
prefix-key: "v5-rust"
- name: Install typos
env:
TYPOS_LINK: https://github.com/crate-ci/typos/releases/download
run: |
curl -L "$TYPOS_LINK/v$TYPOS_VERSION/typos-v$TYPOS_VERSION-x86_64-unknown-linux-musl.tar.gz" |
tar xz -C $HOME/.cargo/bin
- name: run spelling checks using typos
run: cargo xtask run-checks typos

View File

@ -43,13 +43,21 @@ your changes easier. You can create a new branch by using the command
Once you have set up your local repository and created a new branch, you can start making changes.
Be sure to follow the coding standards and guidelines used in the rest of the project.
### Step 6: Run the Pre-Pull Request Script
### Step 6: Validate code before opening a Pull Request
Before you open a pull request, please run [`./run-checks.sh all`](/run-checks.sh). This
will ensure that your changes are in line with our project's standards and guidelines. You can run
this script by opening a terminal, navigating to your local project directory, and typing
`./run-checks`.
Note that under the hood `run-checks` runs the `cargo xtask validate` command which is powered by
the [tracel-xtask crate](https://github.com/tracel-ai/xtask). It is recommended to get familiar with
it as it provides a wide variety of commands to help you work with the code base.
If you have an error related to `torch` installation, see [Burn Torch Backend Installation](./crates/burn-tch/README.md#Installation)
Format and lint errors can often be fixed automatically using the command `cargo xtask fix all`.
### Step 7: Submit a Pull Request
After you've made your changes and run the pre-pull request script, you're ready to submit a pull
@ -87,50 +95,6 @@ You may also want to enable debugging by creating a `.vscode/settings.json` file
4. If you're creating a new library or binary, keep in mind to repeat the step 2 to always keep a fresh list of targets.
## Continuous Integration
### Run checks
On Unix systems, run `run-checks.sh` using this command
```
./run-checks.sh environment
```
On Windows systems, run `run-checks.ps1` using this command:
```
run-checks.ps1 environment
```
The `environment` argument can assume **ONLY** the following values:
- `std` to perform checks using `libstd`
- `no-std` to perform checks on an embedded environment using `libcore`
- `typos` to check for typos in the codebase
- `examples` to check the examples compile
If no `environment` value has been passed, run all checks except examples.
If you have an error related to `torch` installation, see [Burn Torch Backend Installation](./crates/burn-tch/README.md#Installation)
## Continuous Deployment
### Publish crates
Compile `scripts/publish.rs` using this command:
```
rustc scripts/publish.rs --crate-type bin --out-dir scripts
```
Run `scripts/publish` using this command
```
./scripts/publish crate_name
```
where `crate_name` is the name of the crate to publish
## Code Guidelines
We believe in clean and efficient code. While we don't enforce strict coding guidelines, we trust
@ -150,6 +114,11 @@ _Think of `expect()` messages as guidelines for future you and other developers.
This approach ensures that `expect()` messages are informative and aligned with the intended
function outcomes, making debugging and maintenance more straightforward for everyone.
### Writing integration tests
[Integration tests](https://doc.rust-lang.org/rust-by-example/testing/integration_testing.html) should be in a directory called `tests`
besides the `src` directory of a crate. Per convention, they must be implemented in files whose name start with the `test_` prefix.
## Others
To bump for the next version, install `cargo-edit` if its not on your system, and use this command:

341
Cargo.lock generated
View File

@ -23,6 +23,12 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "adler2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "aes"
version = "0.8.4"
@ -193,9 +199,9 @@ checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76"
[[package]]
name = "arrayvec"
version = "0.7.4"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "ash"
@ -297,7 +303,7 @@ dependencies = [
"os_info",
"percent-encoding",
"rand",
"reqwest 0.12.5",
"reqwest 0.12.7",
"rstest",
"serde",
"serde_json",
@ -319,7 +325,7 @@ dependencies = [
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"miniz_oxide 0.7.4",
"object",
"rustc-demangle",
]
@ -403,9 +409,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bitstream-io"
version = "2.5.0"
version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "blas-src"
@ -506,7 +512,7 @@ dependencies = [
"getrandom",
"indicatif",
"rayon",
"reqwest 0.12.5",
"reqwest 0.12.7",
"tokio",
"web-time",
]
@ -759,18 +765,18 @@ dependencies = [
[[package]]
name = "bytemuck"
version = "1.16.3"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83"
checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
version = "1.7.0"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b"
checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
dependencies = [
"proc-macro2",
"quote",
@ -896,12 +902,13 @@ dependencies = [
[[package]]
name = "cc"
version = "1.1.10"
version = "1.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292"
checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932"
dependencies = [
"jobserver",
"libc",
"shlex",
]
[[package]]
@ -1068,9 +1075,9 @@ dependencies = [
[[package]]
name = "cmake"
version = "0.1.50"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130"
checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
dependencies = [
"cc",
]
@ -1525,7 +1532,7 @@ version = "0.15.0"
dependencies = [
"burn",
"csv",
"reqwest 0.12.5",
"reqwest 0.12.7",
"serde",
]
@ -1965,7 +1972,7 @@ dependencies = [
"flume",
"half",
"lebe",
"miniz_oxide",
"miniz_oxide 0.7.4",
"rayon-core",
"smallvec",
"zune-inflate",
@ -2034,12 +2041,12 @@ dependencies = [
[[package]]
name = "flate2"
version = "1.0.31"
version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920"
checksum = "9c0596c1eac1f9e04ed902702e9878208b336edc9d6fddc8a48387349bab3666"
dependencies = [
"crc32fast",
"miniz_oxide",
"miniz_oxide 0.8.0",
]
[[package]]
@ -2490,8 +2497,8 @@ dependencies = [
"aho-corasick",
"bstr",
"log",
"regex-automata",
"regex-syntax",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
]
[[package]]
@ -2608,9 +2615,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.4.5"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab"
checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205"
dependencies = [
"atomic-waker",
"bytes",
@ -2878,7 +2885,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.5",
"h2 0.4.6",
"http 1.1.0",
"http-body 1.0.1",
"httparse",
@ -3004,7 +3011,7 @@ dependencies = [
"globset",
"log",
"memchr",
"regex-automata",
"regex-automata 0.4.7",
"same-file",
"walkdir",
"winapi-util",
@ -3244,9 +3251,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
version = "0.2.157"
version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]]
name = "libfuzzer-sys"
@ -3404,6 +3411,15 @@ dependencies = [
"libc",
]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matrixmultiply"
version = "0.3.9"
@ -3509,6 +3525,15 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "miniz_oxide"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [
"adler2",
]
[[package]]
name = "mio"
version = "0.8.11"
@ -4386,7 +4411,7 @@ dependencies = [
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide",
"miniz_oxide 0.7.4",
]
[[package]]
@ -4900,9 +4925,9 @@ dependencies = [
[[package]]
name = "protobuf"
version = "3.5.0"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df67496db1a89596beaced1579212e9b7c53c22dca1d9745de00ead76573d514"
checksum = "0bcc343da15609eaecd65f8aa76df8dc4209d325131d8219358c0aaaebab0bf6"
dependencies = [
"bytes",
"once_cell",
@ -4912,9 +4937,9 @@ dependencies = [
[[package]]
name = "protobuf-codegen"
version = "3.5.0"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eab09155fad2d39333d3796f67845d43e29b266eea74f7bc93f153f707f126dc"
checksum = "c4d0cde5642ea4df842b13eb9f59ea6fafa26dcb43e3e1ee49120e9757556189"
dependencies = [
"anyhow",
"once_cell",
@ -4927,9 +4952,9 @@ dependencies = [
[[package]]
name = "protobuf-parse"
version = "3.5.0"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a16027030d4ec33e423385f73bb559821827e9ec18c50e7874e4d6de5a4e96f"
checksum = "1b0e9b447d099ae2c4993c0cbb03c7a9d6c937b17f2d56cfc0b1550e6fcfdb76"
dependencies = [
"anyhow",
"indexmap 2.4.0",
@ -4943,9 +4968,9 @@ dependencies = [
[[package]]
name = "protobuf-support"
version = "3.5.0"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70e2d30ab1878b2e72d1e2fc23ff5517799c9929e2cf81a8516f9f4dcf2b9cf3"
checksum = "f0766e3675a627c327e4b3964582594b0e8741305d628a98a5de75a1d15f99b9"
dependencies = [
"thiserror",
]
@ -4961,9 +4986,9 @@ dependencies = [
[[package]]
name = "pulp"
version = "0.18.21"
version = "0.18.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ec8d02258294f59e4e223b41ad7e81c874aa6b15bc4ced9ba3965826da0eed5"
checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
dependencies = [
"bytemuck",
"libm",
@ -5261,9 +5286,9 @@ dependencies = [
[[package]]
name = "redox_users"
version = "0.4.5"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
"getrandom",
"libredox",
@ -5278,8 +5303,17 @@ checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
@ -5290,24 +5324,21 @@ checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
"regex-syntax 0.8.4",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "regression"
version = "0.15.0"
dependencies = [
"burn",
"log",
"serde",
]
[[package]]
name = "relative-path"
version = "1.9.3"
@ -5349,7 +5380,7 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"sync_wrapper 0.1.2",
"system-configuration",
"system-configuration 0.5.1",
"tokio",
"tokio-native-tls",
"tower-service",
@ -5357,14 +5388,14 @@ dependencies = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"winreg 0.50.0",
"winreg",
]
[[package]]
name = "reqwest"
version = "0.12.5"
version = "0.12.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37"
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
dependencies = [
"base64 0.22.1",
"bytes",
@ -5372,7 +5403,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.4.5",
"h2 0.4.6",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
@ -5393,7 +5424,7 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"sync_wrapper 1.0.1",
"system-configuration",
"system-configuration 0.6.1",
"tokio",
"tokio-native-tls",
"tower-service",
@ -5401,7 +5432,7 @@ dependencies = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"winreg 0.52.0",
"windows-registry",
]
[[package]]
@ -5554,9 +5585,9 @@ dependencies = [
[[package]]
name = "rustls-native-certs"
version = "0.7.1"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba"
checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa"
dependencies = [
"openssl-probe",
"rustls-pemfile 2.1.3",
@ -5654,9 +5685,9 @@ dependencies = [
[[package]]
name = "scc"
version = "2.1.14"
version = "2.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79da19444d9da7a9a82b80ecf059eceba6d3129d84a8610fd25ff2364f255466"
checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
dependencies = [
"sdd",
]
@ -5865,6 +5896,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
@ -5916,6 +5953,15 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "simple-regression"
version = "0.15.0"
dependencies = [
"burn",
"log",
"serde",
]
[[package]]
name = "siphasher"
version = "0.3.11"
@ -6025,15 +6071,15 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "stacker"
version = "0.1.15"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce"
checksum = "95a5daa25ea337c85ed954c0496e3bdd2c7308cc3b24cf7b50d04876654c579f"
dependencies = [
"cc",
"cfg-if",
"libc",
"psm",
"winapi",
"windows-sys 0.36.1",
]
[[package]]
@ -6136,6 +6182,9 @@ name = "sync_wrapper"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
@ -6186,7 +6235,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
"system-configuration-sys 0.5.0",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"system-configuration-sys 0.6.0",
]
[[package]]
@ -6199,6 +6259,16 @@ dependencies = [
"libc",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "system-deps"
version = "6.2.2"
@ -6446,7 +6516,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax",
"regex-syntax 0.8.4",
"serde",
"serde_json",
"spm_precompiled",
@ -6458,9 +6528,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.39.2"
version = "1.39.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1"
checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5"
dependencies = [
"backtrace",
"bytes",
@ -6604,6 +6674,34 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracel-xtask"
version = "1.0.0"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4"
dependencies = [
"anyhow",
"clap 4.5.16",
"derive_more",
"env_logger",
"log",
"rand",
"regex",
"serde_json",
"strum",
"tracel-xtask-macros",
"tracing-subscriber",
]
[[package]]
name = "tracel-xtask-macros"
version = "1.0.0"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.75",
]
[[package]]
name = "tracing"
version = "0.1.40"
@ -6665,10 +6763,14 @@ version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
@ -6749,9 +6851,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]]
name = "unicode-xid"
version = "0.2.4"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]]
name = "unicode_categories"
@ -7174,6 +7276,49 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
dependencies = [
"windows_aarch64_msvc 0.36.1",
"windows_i686_gnu 0.36.1",
"windows_i686_msvc 0.36.1",
"windows_x86_64_gnu 0.36.1",
"windows_x86_64_msvc 0.36.1",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@ -7244,6 +7389,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
@ -7256,6 +7407,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
@ -7274,6 +7431,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
@ -7286,6 +7449,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
@ -7310,6 +7479,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
@ -7350,16 +7525,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "winreg"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5"
dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]
name = "wrapcenum-derive"
version = "0.4.1"
@ -7414,16 +7579,12 @@ checksum = "539a77ee7c0de333dcc6da69b177380a0b81e0dacfa4f7344c465a36871ee601"
[[package]]
name = "xtask"
version = "0.6.0"
version = "1.0.0"
dependencies = [
"anyhow",
"clap 4.5.16",
"derive_more",
"env_logger",
"log",
"rand",
"rstest",
"serde_json",
"strum",
"tracel-xtask",
]
[[package]]

View File

@ -160,6 +160,9 @@ portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
cubecl = { version="0.2.0", default-features = false }
cubecl-common = { version="0.2.0", default-features = false }
### For xtask crate ###
tracel-xtask = { git = "https://github.com/tracel-ai/xtask", rev = "921408bc16e74d3ef8ae59356d928fb6706fb8f4" }
[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
opt-level = 2

View File

@ -3,6 +3,11 @@ extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"]
[files]
extend-exclude = [
"*.onnx",
"assets/ModuleSerialization.xml",
"examples/image-classification-web/src/model/label.txt",
]
[default.extend-words]
# Don't correct "arange" which is intentional
arange = "arange"

View File

@ -10,7 +10,7 @@ pub trait OutputProcessor: Send + Sync + 'static {
fn process_line(&self, line: &str);
/// To be called to indicate progress has been made
fn progress(&self);
/// To be called whent the processor has finished processing
/// To be called went the processor has finished processing
fn finish(&self);
}

View File

@ -23,10 +23,10 @@ class Model(nn.Module):
# Subtract a scalar constant from a scalar input
d = k - self.b
# Sutract a scalar from a tensor
# Subtract a scalar from a tensor
x = x - d
# Sutract a tensor from a scalar
# Subtract a tensor from a scalar
x = d - x
return x

View File

@ -24,10 +24,10 @@ class Model(nn.Module):
# Subtract a scalar constant from a scalar input
d = k - self.b
# Sutract a scalar from a tensor
# Subtract a scalar from a tensor
x = x - d
# Sutract a tensor from a scalar
# Subtract a tensor from a scalar
x = d - x
return x

View File

@ -2,7 +2,7 @@
authors = ["aasheeshsingh <aasheeshdtu@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "regression"
name = "simple-regression"
publish = false
version.workspace = true

View File

@ -1,5 +1,10 @@
# This script runs all `burn` checks locally. It may take around 15 minutes on
# the first run.
#!/usr/bin/env pwsh
# Exit immediately if a command exits with a non-zero status.
$ErrorActionPreference = "Stop"
# This script runs all `burn` checks locally. It may take around 15 minutes
# on the first run.
#
# Run `run-checks` using this command:
#
@ -7,16 +12,12 @@
#
# where `environment` can assume **ONLY** the following values:
#
# - `std` to perform checks using `libstd`
# - `no-std` to perform checks on an embedded environment using `libcore`
# - `typos` to check for typos in the codebase
# - `examples` to check the examples compile
# If no `environment` value has been passed, run all checks except examples.
# - `std` to perform validation using `libstd`
# - `no-std` to perform validation on an embedded environment using `libcore`
# - `all` to perform both std and no-std validation
#
# If no `environment` value has been passed, default to `all`.
$exec_env = if ($args.Count -ge 1) { $args[0] } else { "all" }
# Exit if any command fails
$ErrorActionPreference = "Stop"
# Run binary passing the first input parameter, who is mandatory.
# If the input parameter is missing or wrong, it will be the `run-checks`
# binary which will be responsible of arising an error.
cargo xtask run-checks $args[0]
# Run the cargo xtask command with the specified environment
cargo xtask --execution-environment $exec_env validate

View File

@ -12,14 +12,11 @@ set -e
#
# where `environment` can assume **ONLY** the following values:
#
# - `std` to perform checks using `libstd`
# - `no-std` to perform checks on an embedded environment using `libcore`
# - `typos` to check for typos in the codebase
# - `examples` to check the examples compile
# - `std` to perform validation using `libstd`
# - `no-std` to perform validation on an embedded environment using `libcore`
# - `all` to perform both std and no-std validation
#
# If no `environment` value has been passed, run all checks except examples.
# If no `environment` value has been passed.
exec_env=${1:-all}
# Run binary passing the first input parameter, who is mandatory.
# If the input parameter is missing or wrong, it will be the `run-checks`
# binary which will be responsible of arising an error.
cargo xtask run-checks $1
cargo xtask --execution-environment "$exec_env" validate

View File

@ -1,19 +1,15 @@
[package]
name = "xtask"
version = "0.6.0"
version = "1.0.0"
edition = "2021"
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.86"
clap = { version = "4.5.16", features = ["derive"] }
derive_more = { version = "0.99.18", features = ["display"], default-features = false }
env_logger = "0.11.5"
log = "0.4.22"
rand = { workspace = true, features = ["std"] }
serde_json = { version = "1" }
log = { workspace = true }
strum = { workspace = true }
tracel-xtask = { workspace = true }
[dev-dependencies]
rstest = { workspace = true }

View File

@ -1,46 +1,36 @@
use std::{collections::HashMap, path::Path, time::Instant};
use std::path::Path;
use clap::{Args, Subcommand};
use derive_more::Display;
use tracel_xtask::prelude::*;
use crate::{
endgroup, group,
logging::init_logger,
utils::{
cargo::ensure_cargo_crate_is_installed, mdbook::run_mdbook_with_path, process::random_port,
time::format_duration, Params,
},
};
#[derive(Args)]
pub(crate) struct BooksArgs {
#[derive(clap::Args)]
pub struct BooksArgs {
#[command(subcommand)]
book: BookKind,
}
#[derive(Subcommand)]
#[derive(clap::Subcommand)]
pub(crate) enum BookKind {
/// Burn Book, a.k.a. the guide, made for the Burn users.
Burn(BookKindArgs),
/// Contributor book, made for people willing to get all the technical understanding and advices to contribute actively to the project.
/// Contributor book, made for people willing to get all the technical understanding and advice to contribute actively to the project.
Contributor(BookKindArgs),
}
#[derive(Args)]
#[derive(clap::Args)]
pub(crate) struct BookKindArgs {
#[command(subcommand)]
command: BookCommand,
command: BookSubCommand,
}
#[derive(Subcommand, Display)]
pub(crate) enum BookCommand {
#[derive(clap::Subcommand, strum::Display)]
pub(crate) enum BookSubCommand {
/// Build the book
Build,
/// Open the book on the specified port or random port and rebuild it automatically upon changes
Open(OpenArgs),
}
#[derive(Args, Display)]
#[derive(clap::Args)]
pub(crate) struct OpenArgs {
/// Specify the port to open the book on (defaults to a random port if not specified)
#[clap(long, default_value_t = random_port())]
@ -55,15 +45,7 @@ pub(crate) struct Book {
impl BooksArgs {
pub(crate) fn parse(&self) -> anyhow::Result<()> {
init_logger().init();
let start = Instant::now();
Book::run(&self.book)?;
let duration = start.elapsed();
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
Book::run(&self.book)
}
}
@ -91,37 +73,37 @@ impl Book {
&args.command,
),
};
book.execute(command);
book.execute(command)
}
fn execute(&self, command: &BookSubCommand) -> anyhow::Result<()> {
ensure_cargo_crate_is_installed("mdbook", None, None, false)?;
group!("{}: {}", self.name, command);
match command {
BookSubCommand::Build => self.build(),
BookSubCommand::Open(args) => self.open(args),
}?;
endgroup!();
Ok(())
}
fn execute(&self, command: &BookCommand) {
ensure_cargo_crate_is_installed("mdbook");
group!("{}: {}", self.name, command);
match command {
BookCommand::Build => self.build(),
BookCommand::Open(args) => self.open(args),
};
endgroup!();
}
fn build(&self) {
run_mdbook_with_path(
"build",
Params::from([]),
HashMap::new(),
fn build(&self) -> anyhow::Result<()> {
run_process(
"mdbook",
&vec!["build"],
None,
Some(self.path),
"mdbook should build the book successfully",
);
)
}
fn open(&self, args: &OpenArgs) {
run_mdbook_with_path(
"serve",
Params::from(["--open", "--port", &args.port.to_string()]),
HashMap::new(),
fn open(&self, args: &OpenArgs) -> anyhow::Result<()> {
run_process(
"mdbook",
&vec!["serve", "--open", "--port", &args.port.to_string()],
None,
Some(self.path),
"mdbook should build the book successfully",
);
"mdbook should open the book successfully",
)
}
}

View File

@ -0,0 +1,86 @@
use std::collections::HashMap;
use strum::IntoEnumIterator;
use tracel_xtask::prelude::*;
use crate::{ARM_NO_ATOMIC_PTR_TARGET, ARM_TARGET, NO_STD_CRATES, WASM32_TARGET};
#[macros::extend_command_args(BuildCmdArgs, Target, None)]
pub struct BurnBuildCmdArgs {
/// Build in CI mode which excludes unsupported crates.
#[arg(long)]
pub ci: bool,
}
pub(crate) fn handle_command(
mut args: BurnBuildCmdArgs,
exec_env: ExecutionEnvironment,
) -> anyhow::Result<()> {
match exec_env {
ExecutionEnvironment::NoStd => {
[
"Default",
WASM32_TARGET,
ARM_TARGET,
ARM_NO_ATOMIC_PTR_TARGET,
]
.iter()
.try_for_each(|build_target| {
let mut build_args = vec!["--no-default-features"];
let mut env_vars = HashMap::new();
if *build_target != "Default" {
build_args.extend(vec!["--target", *build_target]);
}
if *build_target == ARM_NO_ATOMIC_PTR_TARGET {
env_vars.insert(
"RUSTFLAGS",
"--cfg portable_atomic_unsafe_assume_single_core",
);
}
helpers::custom_crates_build(
NO_STD_CRATES.to_vec(),
build_args,
Some(env_vars),
None,
&format!("no-std with target {}", *build_target),
)
})?;
Ok(())
}
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);
};
}
// Build workspace
base_commands::build::handle_command(args.try_into().unwrap())?;
// Specific additional commands to test specific features
// burn-dataset
helpers::custom_crates_build(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"std with all features",
)?;
Ok(())
}
ExecutionEnvironment::All => ExecutionEnvironment::iter()
.filter(|env| *env != ExecutionEnvironment::All)
.try_for_each(|env| {
handle_command(
BurnBuildCmdArgs {
target: args.target.clone(),
exclude: args.exclude.clone(),
only: args.only.clone(),
ci: args.ci,
},
env,
)
}),
}
}

23
xtask/src/commands/doc.rs Normal file
View File

@ -0,0 +1,23 @@
use tracel_xtask::prelude::*;
pub(crate) fn handle_command(mut args: DocCmdArgs) -> anyhow::Result<()> {
if args.get_command() == DocSubCommand::Build {
args.exclude.push("burn-cuda".to_string());
}
// Execute documentation command on workspace
base_commands::doc::handle_command(args.clone())?;
// Specific additional commands to build other docs
if args.get_command() == DocSubCommand::Build {
// burn-dataset
helpers::custom_crates_doc_build(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"All features",
)?;
}
Ok(())
}

View File

@ -0,0 +1,5 @@
pub(crate) mod books;
pub(crate) mod build;
pub(crate) mod doc;
pub(crate) mod test;
pub(crate) mod validate;

116
xtask/src/commands/test.rs Normal file
View File

@ -0,0 +1,116 @@
use strum::IntoEnumIterator;
use tracel_xtask::prelude::*;
use crate::NO_STD_CRATES;
#[macros::extend_command_args(TestCmdArgs, Target, TestSubCommand)]
pub struct BurnTestCmdArgs {
/// Test in CI mode which excludes unsupported crates.
#[arg(long)]
pub ci: bool,
}
pub(crate) fn handle_command(
mut args: BurnTestCmdArgs,
exec_env: ExecutionEnvironment,
) -> anyhow::Result<()> {
match exec_env {
ExecutionEnvironment::NoStd => {
["Default"].iter().try_for_each(|test_target| {
let mut test_args = vec!["--no-default-features"];
if *test_target != "Default" {
test_args.extend(vec!["--target", *test_target]);
}
helpers::custom_crates_tests(
NO_STD_CRATES.to_vec(),
test_args,
None,
None,
"no-std",
)
})?;
Ok(())
}
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
}
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);
};
// test workspace
base_commands::test::handle_command(args.try_into().unwrap())?;
// Specific additional commands to test specific features
// burn-dataset
helpers::custom_crates_tests(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"std all features",
)?;
// burn-core
helpers::custom_crates_tests(
vec!["burn-core"],
vec!["--features", "test-tch,record-item-custom-serde"],
None,
None,
"std with features: test-tch,record-item-custom-serde",
)?;
if std::env::var("DISABLE_WGPU").is_err() {
helpers::custom_crates_tests(
vec!["burn-core"],
vec!["--features", "test-wgpu"],
None,
None,
"std wgpu",
)?;
}
// MacOS specific tests
#[cfg(target_os = "macos")]
{
// burn-candle
helpers::custom_crates_tests(
vec!["burn-candle"],
vec!["--features", "accelerate"],
None,
None,
"std accelerate",
)?;
// burn-ndarray
helpers::custom_crates_tests(
vec!["burn-ndarray"],
vec!["--features", "blas-accelerate"],
None,
None,
"std blas-accelerate",
)?;
}
Ok(())
}
ExecutionEnvironment::All => ExecutionEnvironment::iter()
.filter(|env| *env != ExecutionEnvironment::All)
.try_for_each(|env| {
handle_command(
BurnTestCmdArgs {
command: args.command.clone(),
target: args.target.clone(),
exclude: args.exclude.clone(),
only: args.only.clone(),
threads: args.threads,
jobs: args.jobs,
ci: args.ci,
},
env,
)
}),
}
}

View File

@ -0,0 +1,111 @@
use tracel_xtask::prelude::*;
use crate::commands::{build::BurnBuildCmdArgs, test::BurnTestCmdArgs};
pub fn handle_command(
args: &ValidateCmdArgs,
exec_env: &ExecutionEnvironment,
) -> anyhow::Result<()> {
let target = Target::Workspace;
let exclude = vec![];
let only = vec![];
if *exec_env == ExecutionEnvironment::Std || *exec_env == ExecutionEnvironment::All {
// ==============
// std validation
// ==============
info!("Run validation for std execution environment...");
// checks
[
CheckSubCommand::Audit,
CheckSubCommand::Format,
CheckSubCommand::Lint,
CheckSubCommand::Typos,
]
.iter()
.try_for_each(|c| {
base_commands::check::handle_command(CheckCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
command: Some(c.clone()),
ignore_audit: args.ignore_audit,
})
})?;
// build
super::build::handle_command(
BurnBuildCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
ci: true,
},
ExecutionEnvironment::Std,
)?;
// tests
super::test::handle_command(
BurnTestCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
threads: None,
jobs: None,
command: Some(TestSubCommand::All),
ci: true,
},
ExecutionEnvironment::Std,
)?;
// documentation
[DocSubCommand::Build, DocSubCommand::Tests]
.iter()
.try_for_each(|c| {
super::doc::handle_command(DocCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
command: Some(c.clone()),
})
})?;
}
if *exec_env == ExecutionEnvironment::NoStd || *exec_env == ExecutionEnvironment::All {
// =================
// no-std validation
// =================
info!("Run validation for no-std execution environment...");
#[cfg(target_os = "linux")]
{
// build
super::build::handle_command(
BurnBuildCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
ci: true,
},
ExecutionEnvironment::NoStd,
)?;
// tests
super::test::handle_command(
BurnTestCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
threads: None,
jobs: None,
command: Some(TestSubCommand::All),
ci: true,
},
ExecutionEnvironment::NoStd,
)?;
}
}
Ok(())
}

View File

@ -1,106 +0,0 @@
use std::{collections::HashMap, time::Instant};
use crate::{
endgroup, group,
logging::init_logger,
utils::{
cargo::{ensure_cargo_crate_is_installed, run_cargo},
rustup::is_current_toolchain_nightly,
time::format_duration,
Params,
},
};
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum DependencyCheck {
/// Run all dependency checks.
#[default]
All,
/// Perform an audit of all dependencies using the cargo-audit crate `<https://crates.io/crates/cargo-audit>`
Audit,
/// Run cargo-deny check `<https://crates.io/crates/cargo-deny>`
Deny,
/// Run cargo-udeps to find unused dependencies `<https://crates.io/crates/cargo-udeps>`
Unused,
}
impl DependencyCheck {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
match self {
Self::Audit => cargo_audit(),
Self::Deny => cargo_deny(),
Self::Unused => cargo_udeps(),
Self::All => {
cargo_audit();
cargo_deny();
cargo_udeps();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo-audit
fn cargo_audit() {
ensure_cargo_crate_is_installed("cargo-audit");
// Run cargo audit
group!("Cargo: run audit checks");
run_cargo(
"audit",
Params::from([]),
HashMap::new(),
"Cargo audit should be installed and it should correctly run",
);
endgroup!();
}
/// Run cargo-deny
fn cargo_deny() {
ensure_cargo_crate_is_installed("cargo-deny");
// Run cargo deny
group!("Cargo: run deny checks");
run_cargo(
"deny",
Params::from(["check"]),
HashMap::new(),
"Cargo deny should be installed and it should correctly run",
);
endgroup!();
}
/// Run cargo-udeps
fn cargo_udeps() {
if is_current_toolchain_nightly() {
ensure_cargo_crate_is_installed("cargo-udeps");
// Run cargo udeps
group!("Cargo: run unused dependencies checks");
run_cargo(
"udeps",
Params::from([]),
HashMap::new(),
"Cargo udeps should be installed and it should correctly run",
);
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to check for unused dependencies.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}

View File

@ -1,67 +0,0 @@
use std::io::Write;
/// Initialise and create a `env_logger::Builder` which follows the
/// GitHub Actions logging syntax when running on CI.
pub(crate) fn init_logger() -> env_logger::Builder {
let mut builder = env_logger::Builder::from_default_env();
builder.target(env_logger::Target::Stdout);
// Find and setup the correct log level
builder.filter(None, get_log_level());
builder.write_style(env_logger::WriteStyle::Always);
// Custom Formatter for Github Actions
if std::env::var("CI").is_ok() {
builder.format(|buf, record| match record.level().as_str() {
"DEBUG" => writeln!(buf, "::debug:: {}", record.args()),
"WARN" => writeln!(buf, "::warning:: {}", record.args()),
"ERROR" => {
writeln!(buf, "::error:: {}", record.args())
}
_ => writeln!(buf, "{}", record.args()),
});
}
builder
}
/// Determine the LogLevel for the logger
fn get_log_level() -> log::LevelFilter {
// DEBUG
match std::env::var("DEBUG") {
Ok(_value) => return log::LevelFilter::Debug,
Err(_err) => (),
}
// ACTIONS_RUNNER_DEBUG
match std::env::var("ACTIONS_RUNNER_DEBUG") {
Ok(_value) => return log::LevelFilter::Debug,
Err(_err) => (),
};
log::LevelFilter::Info
}
/// Group Macro
#[macro_export]
macro_rules! group {
// group!()
($($arg:tt)*) => {
let title = format!($($arg)*);
if std::env::var("CI").is_ok() {
log!(log::Level::Info, "::group::{}", title)
} else {
log!(log::Level::Info, "{}", title)
}
};
}
/// End Group Macro
#[macro_export]
macro_rules! endgroup {
// endgroup!()
() => {
if std::env::var("CI").is_ok() {
log!(log::Level::Info, "::endgroup::")
}
};
}

View File

@ -1,61 +1,76 @@
use clap::{Parser, Subcommand};
mod books;
mod dependencies;
mod logging;
mod publish;
mod runchecks;
mod utils;
mod vulnerabilities;
mod commands;
#[macro_use]
extern crate log;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Command,
}
use std::time::Instant;
use tracel_xtask::prelude::*;
#[derive(Subcommand)]
enum Command {
/// Run commands to manage Burn Books
Books(books::BooksArgs),
/// Run the specified dependencies check locally
Dependencies {
/// The dependency check to run
dependency_check: dependencies::DependencyCheck,
},
/// Publish a crate to crates.io
Publish {
/// The name of the crate to publish on crates.io
name: String,
},
/// Run the specified `burn` tests and checks locally.
RunChecks {
/// The environment to run checks against
#[clap(value_enum, default_value_t = runchecks::CheckType::default())]
env: runchecks::CheckType,
},
/// Run the specified vulnerability check locally. These commands must be called with 'cargo +nightly'.
Vulnerabilities {
/// The vulnerability check to run.
/// For the reference visit the page `<https://doc.rust-lang.org/beta/unstable-book/compiler-flags/sanitizer.html>`
vulnerability_check: vulnerabilities::VulnerabilityCheck,
},
// no-std
const WASM32_TARGET: &str = "wasm32-unknown-unknown";
const ARM_TARGET: &str = "thumbv7m-none-eabi";
const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
const NO_STD_CRATES: &[&str] = &[
"burn",
"burn-core",
"burn-common",
"burn-tensor",
"burn-ndarray",
"burn-no-std-tests",
];
#[macros::base_commands(
Bump,
Check,
Compile,
Coverage,
Doc,
Dependencies,
Fix,
Publish,
Validate,
Vulnerabilities
)]
pub enum Command {
/// Run commands to manage Burn Books.
Books(commands::books::BooksArgs),
/// Build Burn in different modes.
Build(commands::build::BurnBuildCmdArgs),
/// Test Burn.
Test(commands::test::BurnTestCmdArgs),
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let start = Instant::now();
let args = init_xtask::<Command>()?;
if args.execution_environment == ExecutionEnvironment::NoStd {
// Install additional targets for no-std execution environments
rustup_add_target(WASM32_TARGET)?;
rustup_add_target(ARM_TARGET)?;
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET)?;
}
match args.command {
Command::Books(args) => args.parse(),
Command::Dependencies { dependency_check } => dependency_check.run(),
Command::Publish { name } => publish::run(name),
Command::RunChecks { env } => env.run(),
Command::Vulnerabilities {
vulnerability_check,
} => vulnerability_check.run(),
}
Command::Books(cmd_args) => cmd_args.parse(),
Command::Build(cmd_args) => {
commands::build::handle_command(cmd_args, args.execution_environment)
}
Command::Doc(cmd_args) => commands::doc::handle_command(cmd_args),
Command::Test(cmd_args) => {
commands::test::handle_command(cmd_args, args.execution_environment)
}
Command::Validate(cmd_args) => {
commands::validate::handle_command(&cmd_args, &args.execution_environment)
}
_ => dispatch_base_commands(args),
}?;
let duration = start.elapsed();
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}

View File

@ -1,114 +0,0 @@
//! This script publishes a crate on `crates.io`.
//!
//! To run the script:
//!
//! cargo xtask publish INPUT_CRATE
use std::{collections::HashMap, env, process::Command, str};
use crate::{
endgroup, group,
utils::{cargo::run_cargo, Params},
};
// Crates.io API token
const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN";
// Obtain local crate version
fn local_version(crate_name: &str) -> String {
// Obtain local crate version contained in cargo pkgid data
let cargo_pkgid_output = Command::new("cargo")
.args(["pkgid", "-p", crate_name])
.output()
.expect("Failed to run cargo pkgid");
// Convert cargo pkgid output into a str
let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout)
.expect("Failed to convert pkgid output into a str");
// Extract only the local crate version from str
let (_, local_version) = cargo_pkgid_str
.split_once('#')
.expect("Failed to get local crate version");
local_version.trim_end().to_string()
}
// Obtain remote crate version
fn remote_version(crate_name: &str) -> Option<String> {
// Obtain remote crate version contained in cargo search data
let cargo_search_output = Command::new("cargo")
.args(["search", crate_name, "--limit", "1"])
.output()
.expect("Failed to run cargo search");
// Cargo search returns an empty string in case of a crate not present on
// crates.io
if cargo_search_output.stdout.is_empty() {
None
} else {
// Convert cargo search output into a str
let remote_version_str = str::from_utf8(&cargo_search_output.stdout)
.expect("Failed to convert cargo search output into a str");
// Extract only the remote crate version from str
remote_version_str
.split_once('=')
.and_then(|(_, second)| second.trim_start().split_once(' '))
.map(|(s, _)| s.trim_matches('"').to_string())
}
}
fn publish(crate_name: String) {
// Perform dry-run to ensure everything is good for publishing
let dry_run_params = Params::from(["-p", &crate_name, "--dry-run"]);
run_cargo(
"publish",
dry_run_params,
HashMap::new(),
"The cargo publish --dry-run should complete successfully, indicating readiness for actual publication",
);
let crates_io_token =
env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token");
let envs = HashMap::from([("CRATES_IO_API_TOKEN", crates_io_token.clone())]);
let publish_params = Params::from(vec!["-p", &crate_name, "--token", &crates_io_token]);
// Actually publish the crate
run_cargo(
"publish",
publish_params,
envs,
"The crate should be successfully published",
);
}
pub(crate) fn run(crate_name: String) -> anyhow::Result<()> {
group!("Publishing {}...\n", crate_name);
// Retrieve local version for crate
let local_version = local_version(&crate_name);
info!("{crate_name} local version: {local_version}");
// Retrieve remote version for crate if it exists
match remote_version(&crate_name) {
Some(remote_version) => {
info!("{crate_name} remote version: {remote_version}\n");
// Early return if we don't need to publish the crate
if local_version == remote_version {
info!("Remote version {remote_version} is up to date, skipping deployment");
return Ok(());
}
}
None => info!("\nFirst time publishing {crate_name} on crates.io!\n"),
}
// Publish the crate
publish(crate_name);
endgroup!();
Ok(())
}

View File

@ -1,427 +0,0 @@
//! This script is run before a PR is created.
//!
//! It is used to check that the code compiles and passes all tests.
//!
//! It is also used to check that the code is formatted correctly and passes clippy.
use std::collections::HashMap;
use std::env;
use std::process::{Command, Stdio};
use std::str;
use std::time::Instant;
use crate::logging::init_logger;
use crate::utils::cargo::{run_cargo, run_cargo_with_path};
use crate::utils::process::{handle_child_process, run_command};
use crate::utils::rustup::{rustup_add_component, rustup_add_target};
use crate::utils::time::format_duration;
use crate::utils::workspace::{get_workspace_members, WorkspaceMemberType};
use crate::utils::Params;
use crate::{endgroup, group};
// Targets constants
const WASM32_TARGET: &str = "wasm32-unknown-unknown";
const ARM_TARGET: &str = "thumbv7m-none-eabi";
const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum CheckType {
/// Run all checks except examples
#[default]
All,
/// Run `std` environment checks
Std,
/// Run `no-std` environment checks
NoStd,
/// Check for typos
Typos,
/// Test the examples
Examples,
}
impl CheckType {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
// The environment can assume ONLY "std", "no_std", "typos", "examples"
//
// Depending on the input argument, the respective environment checks
// are run.
//
// If no `environment` value has been passed, run all checks except examples.
match self {
Self::Std => std_checks(),
Self::NoStd => no_std_checks(),
Self::Typos => check_typos(),
Self::Examples => check_examples(),
Self::All => {
/* Run all checks */
check_typos();
std_checks();
no_std_checks();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo build command
fn cargo_build(params: Params, envs: Option<HashMap<&str, String>>) {
// Run cargo build
run_cargo(
"build",
params + "--color=always",
envs.unwrap_or_default(),
"Failed to run cargo build",
);
}
/// Run cargo install command
fn cargo_install(params: Params) {
// Run cargo install
run_cargo(
"install",
params + "--color=always",
HashMap::new(),
"Failed to run cargo install",
);
}
/// Run cargo test command
fn cargo_test(params: Params) {
// Run cargo test
run_cargo(
"test",
params + "--color=always" + "--" + "--color=always",
HashMap::new(),
"Failed to run cargo test",
);
}
/// Run cargo fmt command
fn cargo_fmt() {
group!("Cargo: fmt");
run_cargo(
"fmt",
["--check", "--all", "--", "--color=always"].into(),
HashMap::new(),
"Failed to run cargo fmt",
);
endgroup!();
}
/// Run cargo clippy command
fn cargo_clippy() {
if std::env::var("CI").is_ok() {
return;
}
// Run cargo clippy
run_cargo(
"clippy",
["--color=always", "--all-targets", "--", "-D", "warnings"].into(),
HashMap::new(),
"Failed to run cargo clippy",
);
}
/// Run cargo doc command
fn cargo_doc(params: Params) {
// Run cargo doc
run_cargo(
"doc",
params + "--color=always",
HashMap::new(),
"Failed to run cargo doc",
);
}
// Build and test a crate in a no_std environment
fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]) {
group!("Checks: {} (no-std)", crate_name);
// Run cargo build --no-default-features
cargo_build(
Params::from(["-p", crate_name, "--no-default-features"]) + extra_args,
None,
);
// Run cargo test --no-default-features
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
// Run cargo build --no-default-features --target wasm32-unknown-unknowns
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
WASM32_TARGET,
]) + extra_args,
None,
);
// Run cargo build --no-default-features --target thumbv7m-none-eabi
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_TARGET,
]) + extra_args,
None,
);
// Run cargo build --no-default-features --target thumbv6m-none-eabi
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_NO_ATOMIC_PTR_TARGET,
]) + extra_args,
Some(HashMap::from([(
"RUSTFLAGS",
"--cfg portable_atomic_unsafe_assume_single_core".to_string(),
)])),
);
endgroup!();
}
// Setup code coverage
fn setup_coverage() {
// Install llvm-tools-preview
rustup_add_component("llvm-tools-preview");
// Set coverage environment variables
env::set_var("RUSTFLAGS", "-Cinstrument-coverage");
env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw");
}
// Run grcov to produce lcov.info
fn run_grcov() {
// grcov arguments
#[rustfmt::skip]
let args = [
".",
"--binary-path", "./target/debug/",
"-s", ".",
"-t", "lcov",
"--branch",
"--ignore-not-existing",
"--ignore", "/*", // It excludes std library code coverage from analysis
"--ignore", "xtask/*",
"--ignore", "examples/*",
"-o", "lcov.info",
];
run_command(
"grcov",
&args,
"Failed to run grcov",
"Failed to wait for grcov child process",
);
}
// Run no_std checks
fn no_std_checks() {
// Install wasm32 target
rustup_add_target(WASM32_TARGET);
// Install ARM target
rustup_add_target(ARM_TARGET);
// Install ARM no atomic ptr target
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET);
// Run checks for the following crates
build_and_test_no_std("burn", []);
build_and_test_no_std("burn-core", []);
build_and_test_no_std("burn-common", []);
build_and_test_no_std("burn-tensor", []);
build_and_test_no_std("burn-ndarray", []);
build_and_test_no_std("burn-no-std-tests", []);
}
// Test burn-core with tch and wgpu backend
fn burn_core_std() {
// Run cargo test --features test-tch, record-item-custom-serde
group!("Test: burn-core (tch) and record-item-custom-serde");
cargo_test(
[
"-p",
"burn-core",
"--features",
"test-tch,record-item-custom-serde,",
]
.into(),
);
endgroup!();
// Run cargo test --features test-wgpu
if std::env::var("DISABLE_WGPU").is_err() {
group!("Test: burn-core (wgpu)");
cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into());
endgroup!();
}
}
// Test burn-dataset features
fn burn_dataset_features_std() {
group!("Checks: burn-dataset (all-features)");
// Run cargo build --all-features
cargo_build(["-p", "burn-dataset", "--all-features"].into(), None);
// Run cargo test --all-features
cargo_test(["-p", "burn-dataset", "--all-features"].into());
// Run cargo doc --all-features
cargo_doc(["-p", "burn-dataset", "--all-features", "--no-deps"].into());
endgroup!();
}
// macOS only checks
#[cfg(target_os = "macos")]
fn macos_checks() {
// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate
group!("Checks: burn-candle (accelerate)");
cargo_test(["-p", "burn-candle", "--features", "accelerate"].into());
endgroup!();
// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate
group!("Checks: burn-ndarray (accelerate)");
cargo_test(["-p", "burn-ndarray", "--features", "blas-accelerate"].into());
endgroup!();
}
fn std_checks() {
// Set RUSTDOCFLAGS environment variable to treat warnings as errors
// for the documentation build
env::set_var("RUSTDOCFLAGS", "-D warnings");
// Check if COVERAGE environment variable is set
let is_coverage = std::env::var("COVERAGE").is_ok();
let disable_wgpu = std::env::var("DISABLE_WGPU").is_ok();
// Check format
cargo_fmt();
// Check clippy lints
cargo_clippy();
// Produce documentation for each workspace member
group!("Docs: crates");
let mut params = Params::from(["--workspace", "--no-deps"]);
// Exclude burn-cuda on all platforms
params.params.push("--exclude".to_string());
params.params.push("burn-cuda".to_string());
cargo_doc(params);
endgroup!();
// Setup code coverage
if is_coverage {
setup_coverage();
}
// Build & test each member in workspace
let members = get_workspace_members(WorkspaceMemberType::Crate);
for member in members {
if disable_wgpu && member.name == "burn-wgpu" {
continue;
}
if member.name == "burn-cuda" {
// burn-cuda requires CUDA Toolkit which is not currently setup on our CI runners
continue;
}
if member.name == "burn-tch" {
continue;
}
group!("Checks: {}", member.name);
cargo_build(Params::from(["-p", &member.name]), None);
cargo_test(Params::from(["-p", &member.name]));
endgroup!();
}
// Test burn-candle with accelerate (macOS only)
#[cfg(target_os = "macos")]
macos_checks();
// Test burn-dataset features
burn_dataset_features_std();
// Test burn-core with tch and wgpu backend
burn_core_std();
// Run grcov and produce lcov.info
if is_coverage {
run_grcov();
}
}
fn check_typos() {
// This path defines where typos-cli is installed on different
// operating systems.
let typos_cli_path = std::env::var("CARGO_HOME")
.map(|v| std::path::Path::new(&v).join("bin/typos-cli"))
.unwrap();
// Do not run cargo install on CI to speed up the computation.
// Check whether the file has been installed on
if std::env::var("CI").is_err() && !typos_cli_path.exists() {
// Install typos-cli
cargo_install(["typos-cli", "--version", "1.16.5"].into());
}
info!("Running typos check \n\n");
// Run typos command as child process
let typos = Command::new("typos")
.args(["--exclude", "**/*.onnx"])
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn()
.expect("Failed to run typos");
// Handle typos child process
handle_child_process(typos, "Failed to wait for typos child process");
}
fn check_examples() {
let members = get_workspace_members(WorkspaceMemberType::Example);
for member in members {
if member.name == "notebook" {
continue;
}
group!("Checks: Example - {}", member.name);
run_cargo_with_path(
"check",
["--examples"].into(),
HashMap::new(),
Some(member.path),
"Failed to check example",
);
endgroup!();
}
}

View File

@ -1,66 +0,0 @@
use std::{
collections::HashMap,
path::Path,
process::{Command, Stdio},
};
use crate::{endgroup, group, utils::process::handle_child_process};
use super::Params;
/// Run a cargo command
pub(crate) fn run_cargo(command: &str, params: Params, envs: HashMap<&str, String>, error: &str) {
run_cargo_with_path::<String>(command, params, envs, None, error)
}
/// Run a cargo command with the passed directory as the current directory
pub(crate) fn run_cargo_with_path<P: AsRef<Path>>(
command: &str,
params: Params,
envs: HashMap<&str, String>,
path: Option<P>,
error: &str,
) {
info!("cargo {} {}\n", command, params.params.join(" "));
let mut cargo = Command::new("cargo");
cargo
.env("CARGO_INCREMENTAL", "0")
.envs(&envs)
.arg(command)
.args(&params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
if let Some(path) = path {
cargo.current_dir(path);
}
// Handle cargo child process
let cargo_process = cargo.spawn().expect(error);
handle_child_process(cargo_process, "Cargo process should run flawlessly");
}
/// Ensure that a cargo crate is installed
pub(crate) fn ensure_cargo_crate_is_installed(crate_name: &str) {
if !is_cargo_crate_installed(crate_name) {
group!("Cargo: install crate '{}'", crate_name);
run_cargo(
"install",
[crate_name].into(),
HashMap::new(),
&format!("crate '{}' should be installed", crate_name),
);
endgroup!();
}
}
/// Returns true if the passed cargo crate is installed locally
fn is_cargo_crate_installed(crate_name: &str) -> bool {
let output = Command::new("cargo")
.arg("install")
.arg("--list")
.output()
.expect("Should get the list of installed cargo commands");
let output_str = String::from_utf8_lossy(&output.stdout);
output_str.lines().any(|line| line.contains(crate_name))
}

View File

@ -1,35 +0,0 @@
use std::{
collections::HashMap,
path::Path,
process::{Command, Stdio},
};
use crate::utils::process::handle_child_process;
use super::Params;
/// Run a mdbook command with the passed directory as the current directory
pub(crate) fn run_mdbook_with_path<P: AsRef<Path>>(
command: &str,
params: Params,
envs: HashMap<&str, String>,
path: Option<P>,
error: &str,
) {
info!("mdbook {} {}\n", command, params.params.join(" "));
let mut mdbook = Command::new("mdbook");
mdbook
.envs(&envs)
.arg(command)
.args(&params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
if let Some(path) = path {
mdbook.current_dir(path);
}
// Handle mdbook child process
let mdbook_process = mdbook.spawn().expect(error);
handle_child_process(mdbook_process, "mdbook process should run flawlessly");
}

View File

@ -1,50 +0,0 @@
pub(crate) mod cargo;
pub(crate) mod mdbook;
pub(crate) mod process;
pub(crate) mod rustup;
pub(crate) mod time;
pub(crate) mod workspace;
pub(crate) struct Params {
pub params: Vec<String>,
}
impl<const N: usize> From<[&str; N]> for Params {
fn from(value: [&str; N]) -> Self {
Self {
params: value.iter().map(|v| v.to_string()).collect(),
}
}
}
impl From<&str> for Params {
fn from(value: &str) -> Self {
Self {
params: vec![value.to_string()],
}
}
}
impl From<Vec<&str>> for Params {
fn from(value: Vec<&str>) -> Self {
Self {
params: value.iter().map(|s| s.to_string()).collect(),
}
}
}
impl std::fmt::Display for Params {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.params.join(" ").as_str())
}
}
impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
type Output = Params;
fn add(mut self, rhs: Rhs) -> Self::Output {
let rhs: Params = rhs.into();
self.params.extend(rhs.params);
self
}
}

View File

@ -1,38 +0,0 @@
use rand::Rng;
use std::process::{Child, Command, Stdio};
/// Handle child process
pub(crate) fn handle_child_process(mut child: Child, error: &str) {
// Wait for the child process to finish
let status = child.wait().expect(error);
// If exit status is not a success, terminate the process with an error
if !status.success() {
// Use the exit code associated to a command to terminate the process,
// if any exit code had been found, use the default value 1
std::process::exit(status.code().unwrap_or(1));
}
}
/// Run a command
pub(crate) fn run_command(command: &str, args: &[&str], command_error: &str, child_error: &str) {
// Format command
info!("{command} {}\n\n", args.join(" "));
// Run command as child process
let command = Command::new(command)
.args(args)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn()
.expect(command_error);
// Handle command child process
handle_child_process(command, child_error);
}
/// Return a random port between 3000 and 9999
pub(crate) fn random_port() -> u16 {
let mut rng = rand::thread_rng();
rng.gen_range(3000..=9999)
}

View File

@ -1,68 +0,0 @@
use std::process::{Command, Stdio};
use crate::{endgroup, group, utils::process::handle_child_process};
use super::Params;
/// Run rustup command
pub(crate) fn rustup(command: &str, params: Params, expected: &str) {
info!("rustup {} {}\n", command, params);
// Run rustup
let mut rustup = Command::new("rustup");
rustup
.arg(command)
.args(params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
let cargo_process = rustup.spawn().expect(expected);
handle_child_process(cargo_process, "Failed to wait for rustup child process");
}
/// Add a Rust target
pub(crate) fn rustup_add_target(target: &str) {
group!("Rustup: add target {}", target);
rustup(
"target",
Params::from(["add", target]),
"Target should be added",
);
endgroup!();
}
/// Add a Rust component
pub(crate) fn rustup_add_component(component: &str) {
group!("Rustup: add component {}", component);
rustup(
"component",
Params::from(["add", component]),
"Component should be added",
);
endgroup!();
}
// Returns the output of the rustup command to get the installed targets
pub(crate) fn rustup_get_installed_targets() -> String {
let output = Command::new("rustup")
.args(["target", "list", "--installed"])
.stdout(Stdio::piped())
.output()
.expect("Rustup command should execute successfully");
String::from_utf8(output.stdout).expect("Output should be valid UTF-8")
}
/// Returns true if the current toolchain is the nightly
pub(crate) fn is_current_toolchain_nightly() -> bool {
let output = Command::new("rustup")
.arg("show")
.output()
.expect("Should get the list of installed Rust toolchains");
let output_str = String::from_utf8_lossy(&output.stdout);
for line in output_str.lines() {
// look for the "rustc.*-nightly" line
if line.contains("rustc") && line.contains("-nightly") {
return true;
}
}
// assume we are using a stable toolchain if we did not find the nightly compiler
false
}

View File

@ -1,15 +0,0 @@
use std::time::Duration;
/// Print duration as HH:MM:SS format
pub(crate) fn format_duration(duration: &Duration) -> String {
let seconds = duration.as_secs();
let minutes = seconds / 60;
let hours = minutes / 60;
let remaining_minutes = minutes % 60;
let remaining_seconds = seconds % 60;
format!(
"{:02}:{:02}:{:02}",
hours, remaining_minutes, remaining_seconds
)
}

View File

@ -1,90 +0,0 @@
use std::{path::Path, process::Command};
use serde_json::Value;
const MEMBER_PATH_PREFIX: &str = if cfg!(target_os = "windows") {
"path+file:///"
} else {
"path+file://"
};
pub(crate) enum WorkspaceMemberType {
Crate,
Example,
}
#[derive(Debug)]
pub(crate) struct WorkspaceMember {
pub(crate) name: String,
pub(crate) path: String,
}
impl WorkspaceMember {
fn new(name: String, path: String) -> Self {
Self { name, path }
}
}
/// Get workspace crates
pub(crate) fn get_workspace_members(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember> {
// Run `cargo metadata` command to get project metadata
let output = Command::new("cargo")
.arg("metadata")
.output()
.expect("Failed to execute command");
// Parse the JSON output
let metadata: Value = serde_json::from_slice(&output.stdout).expect("Failed to parse JSON");
// Extract workspaces from the metadata, excluding examples/ and xtask
let workspaces = metadata["workspace_members"]
.as_array()
.expect("Expected an array of workspace members")
.iter()
.filter_map(|member| {
let member_str = member.as_str()?;
let has_whitespace = member_str.chars().any(|c| c.is_whitespace());
let (name, path) = if has_whitespace {
parse_workspace_member0(member_str)?
} else {
parse_workspace_member1(member_str)?
};
match w_type {
WorkspaceMemberType::Crate if name != "xtask" && !path.contains("examples/") => {
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
}
WorkspaceMemberType::Example if name != "xtask" && path.contains("examples/") => {
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
}
_ => None,
}
})
.collect();
workspaces
}
/// Legacy cargo metadata format for member specs (rust < 1.77)
/// Example:
/// "backend-comparison 0.13.0 (path+file:///Users/username/burn/backend-comparison)"
fn parse_workspace_member0(specs: &str) -> Option<(String, String)> {
let parts: Vec<_> = specs.split_whitespace().collect();
let (name, path) = (parts.first()?.to_owned(), parts.last()?.to_owned());
// skip the first character because it is a '('
let path = path
.chars()
.skip(1)
.collect::<String>()
.replace(MEMBER_PATH_PREFIX, "")
.replace(')', "");
Some((name.to_string(), path.to_string()))
}
/// Cargo metadata format for member specs (rust >= 1.77)
/// Example:
/// "path+file:///Users/username/burn/backend-comparison#0.13.0"
fn parse_workspace_member1(specs: &str) -> Option<(String, String)> {
let no_prefix = specs.replace(MEMBER_PATH_PREFIX, "").replace(')', "");
let path = Path::new(no_prefix.split_once('#')?.0);
let name = path.file_name()?.to_str()?;
let path = path.to_str()?;
Some((name.to_string(), path.to_string()))
}

View File

@ -1,396 +0,0 @@
use std::collections::HashMap;
use std::time::Instant;
use crate::logging::init_logger;
use crate::utils::cargo::{ensure_cargo_crate_is_installed, run_cargo};
use crate::utils::rustup::{
is_current_toolchain_nightly, rustup_add_component, rustup_get_installed_targets,
};
use crate::utils::time::format_duration;
use crate::utils::Params;
use crate::{endgroup, group};
use std::fmt;
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum VulnerabilityCheck {
/// Run all most useful vulnerability checks.
#[default]
All,
/// Run Address sanitizer (memory error detector)
AddressSanitizer,
/// Run LLVM Control Flow Integrity (CFI) (provides forward-edge control flow protection)
ControlFlowIntegrity,
/// Run newer variant of Address sanitizer (memory error detector similar to AddressSanitizer, but based on partial hardware assistance)
HWAddressSanitizer,
/// Run Kernel LLVM Control Flow Integrity (KCFI) (provides forward-edge control flow protection for operating systems kernels)
KernelControlFlowIntegrity,
/// Run Leak sanitizer (run-time memory leak detector)
LeakSanitizer,
/// Run memory sanitizer (detector of uninitialized reads)
MemorySanitizer,
/// Run another address sanitizer (like AddressSanitizer and HardwareAddressSanitizer but with lower overhead suitable for use as hardening for production binaries)
MemTagSanitizer,
/// Run nightly-only checks through cargo-careful `<https://crates.io/crates/cargo-careful>`
NightlyChecks,
/// Run SafeStack check (provides backward-edge control flow protection by separating
/// stack into safe and unsafe regions)
SafeStack,
/// Run ShadowCall check (provides backward-edge control flow protection - aarch64 only)
ShadowCallStack,
/// Run Thread sanitizer (data race detector)
ThreadSanitizer,
}
impl VulnerabilityCheck {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
match self {
Self::NightlyChecks => cargo_careful(),
Self::AddressSanitizer => Sanitizer::Address.run_tests(),
Self::ControlFlowIntegrity => Sanitizer::CFI.run_tests(),
Self::HWAddressSanitizer => Sanitizer::HWAddress.run_tests(),
Self::KernelControlFlowIntegrity => Sanitizer::KCFI.run_tests(),
Self::LeakSanitizer => Sanitizer::Leak.run_tests(),
Self::MemorySanitizer => Sanitizer::Memory.run_tests(),
Self::MemTagSanitizer => Sanitizer::MemTag.run_tests(),
Self::SafeStack => Sanitizer::SafeStack.run_tests(),
Self::ShadowCallStack => Sanitizer::ShadowCallStack.run_tests(),
Self::ThreadSanitizer => Sanitizer::Thread.run_tests(),
Self::All => {
cargo_careful();
Sanitizer::Address.run_tests();
Sanitizer::Leak.run_tests();
Sanitizer::Memory.run_tests();
Sanitizer::SafeStack.run_tests();
Sanitizer::Thread.run_tests();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo-careful
fn cargo_careful() {
if is_current_toolchain_nightly() {
ensure_cargo_crate_is_installed("cargo-careful");
rustup_add_component("rust-src");
// prepare careful sysroot
group!("Cargo: careful setup");
run_cargo(
"careful",
Params::from(["setup"]),
HashMap::new(),
"Cargo sysroot should be available",
);
endgroup!();
// Run cargo careful
group!("Cargo: run careful checks");
run_cargo(
"careful",
Params::from(["test"]),
HashMap::new(),
"Cargo careful should be installed and it should correctly run",
);
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to run nightly checks.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}
// Represents the various sanitizer available in nightly compiler
// source: https://doc.rust-lang.org/beta/unstable-book/compiler-flags/sanitizer.html
#[allow(clippy::upper_case_acronyms)]
enum Sanitizer {
Address,
CFI,
HWAddress,
KCFI,
Leak,
Memory,
MemTag,
SafeStack,
ShadowCallStack,
Thread,
}
impl fmt::Display for Sanitizer {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Sanitizer::Address => write!(f, "AddressSanitizer"),
Sanitizer::CFI => write!(f, "ControlFlowIntegrity"),
Sanitizer::HWAddress => write!(f, "HWAddressSanitizer"),
Sanitizer::KCFI => write!(f, "KernelControlFlowIntegrity"),
Sanitizer::Leak => write!(f, "LeakSanitizer"),
Sanitizer::Memory => write!(f, "MemorySanitizer"),
Sanitizer::MemTag => write!(f, "MemTagSanitizer"),
Sanitizer::SafeStack => write!(f, "SafeStack"),
Sanitizer::ShadowCallStack => write!(f, "ShadowCallStack"),
Sanitizer::Thread => write!(f, "ThreadSanitizer"),
}
}
}
impl Sanitizer {
const DEFAULT_RUSTFLAGS: &'static str = "-Copt-level=3";
fn run_tests(&self) {
if is_current_toolchain_nightly() {
group!("Sanitizer: {}", self.to_string());
let retriever = RustupTargetRetriever;
if self.is_target_supported(&retriever) {
let envs = vec![
(
"RUSTFLAGS",
format!("{} {}", self.flags(), Sanitizer::DEFAULT_RUSTFLAGS),
),
("RUSTDOCFLAGS", self.flags().to_string()),
];
let features = self.cargo_features();
let mut args = vec!["--", "--color=always", "--no-capture"];
args.extend(features);
run_cargo(
"test",
args.into(),
envs.into_iter().collect(),
"Failed to run cargo test",
);
} else {
info!("No supported target found for this sanitizer.");
}
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to run this check.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}
fn flags(&self) -> &'static str {
match self {
Sanitizer::Address => "-Zsanitizer=address",
Sanitizer::CFI => "-Zsanitizer=cfi -Clto",
Sanitizer::HWAddress => "-Zsanitizer=hwaddress -Ctarget-feature=+tagged-globals",
Sanitizer::KCFI => "-Zsanitizer=kcfi",
Sanitizer::Leak => "-Zsanitizer=leak",
Sanitizer::Memory => "-Zsanitizer=memory -Zsanitizer-memory-track-origins",
Sanitizer::MemTag => "--Zsanitizer=memtag -Ctarget-feature=\"+mte\"",
Sanitizer::SafeStack => "-Zsanitizer=safestack",
Sanitizer::ShadowCallStack => "-Zsanitizer=shadow-call-stack",
Sanitizer::Thread => "-Zsanitizer=thread",
}
}
fn cargo_features(&self) -> Vec<&str> {
match self {
Sanitizer::CFI => vec!["-Zbuild-std", "--target x86_64-unknown-linux-gnu"],
_ => vec![],
}
}
fn supported_targets(&self) -> Vec<Target> {
match self {
Sanitizer::Address => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownFuchsia,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownFuchsia,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::CFI => vec![Target::X8664UnknownLinuxGnu],
Sanitizer::HWAddress => {
vec![Target::Aarch64LinuxAndroid, Target::Aarch64UnknownLinuxGnu]
}
Sanitizer::KCFI => vec![
Target::Aarch64LinuxAndroid,
Target::Aarch64UnknownLinuxGnu,
Target::X8664LinuxAndroid,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::Leak => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::Memory => vec![
Target::Aarch64UnknownLinuxGnu,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::MemTag => vec![Target::Aarch64LinuxAndroid, Target::Aarch64UnknownLinuxGnu],
Sanitizer::SafeStack => vec![Target::X8664UnknownLinuxGnu],
Sanitizer::ShadowCallStack => vec![Target::Aarch64LinuxAndroid],
Sanitizer::Thread => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
}
}
// Returns true if the sanitizer is supported by the currently installed targets
fn is_target_supported<T: TargetRetriever>(&self, retriever: &T) -> bool {
let installed_targets = retriever.get_installed_targets();
let supported = self.supported_targets();
installed_targets.iter().any(|installed| {
let installed_target = Target::from_str(installed.trim()).unwrap_or(Target::Unknown);
supported.iter().any(|target| target == &installed_target)
})
}
}
// Constants for target names
const AARCH64_APPLE_DARWIN: &str = "aarch64-apple-darwin";
const AARCH64_LINUX_ANDROID: &str = "aarch64-linux-android";
const AARCH64_UNKNOWN_FUCHSIA: &str = "aarch64-unknown-fuchsia";
const AARCH64_UNKNOWN_LINUX_GNU: &str = "aarch64-unknown-linux-gnu";
const X8664_APPLE_DARWIN: &str = "x86_64-apple-darwin";
const X8664_LINUX_ANDROID: &str = "x86_64-linux-android";
const X8664_UNKNOWN_FUCHSIA: &str = "x86_64-unknown-fuchsia";
const X8664_UNKNOWN_FREEBSD: &str = "x86_64-unknown-freebsd";
const X8664_UNKNOWN_LINUX_GNU: &str = "x86_64-unknown-linux-gnu";
trait TargetRetriever {
fn get_installed_targets(&self) -> Vec<String>;
}
struct RustupTargetRetriever;
impl TargetRetriever for RustupTargetRetriever {
fn get_installed_targets(&self) -> Vec<String> {
rustup_get_installed_targets()
.lines()
.map(|s| s.to_string())
.collect()
}
}
// Represents Rust targets
// Remark: we list only the targets that are supported by sanitizers
#[derive(Debug, PartialEq)]
enum Target {
Aarch64AppleDarwin,
Aarch64LinuxAndroid,
Aarch64UnknownFuchsia,
Aarch64UnknownLinuxGnu,
X8664AppleDarwin,
X8664LinuxAndroid,
X8664UnknownFuchsia,
X8664UnknownFreebsd,
X8664UnknownLinuxGnu,
Unknown,
}
impl Target {
fn from_str(s: &str) -> Option<Self> {
match s {
AARCH64_APPLE_DARWIN => Some(Self::Aarch64AppleDarwin),
AARCH64_LINUX_ANDROID => Some(Self::Aarch64LinuxAndroid),
AARCH64_UNKNOWN_FUCHSIA => Some(Self::Aarch64UnknownFuchsia),
AARCH64_UNKNOWN_LINUX_GNU => Some(Self::Aarch64UnknownLinuxGnu),
X8664_APPLE_DARWIN => Some(Self::X8664AppleDarwin),
X8664_LINUX_ANDROID => Some(Self::X8664LinuxAndroid),
X8664_UNKNOWN_FUCHSIA => Some(Self::X8664UnknownFuchsia),
X8664_UNKNOWN_FREEBSD => Some(Self::X8664UnknownFreebsd),
X8664_UNKNOWN_LINUX_GNU => Some(Self::X8664UnknownLinuxGnu),
_ => None,
}
}
}
impl fmt::Display for Target {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let target_str = match self {
Target::Aarch64AppleDarwin => AARCH64_APPLE_DARWIN,
Target::Aarch64LinuxAndroid => AARCH64_LINUX_ANDROID,
Target::Aarch64UnknownFuchsia => AARCH64_UNKNOWN_FUCHSIA,
Target::Aarch64UnknownLinuxGnu => AARCH64_UNKNOWN_LINUX_GNU,
Target::X8664AppleDarwin => X8664_APPLE_DARWIN,
Target::X8664LinuxAndroid => X8664_LINUX_ANDROID,
Target::X8664UnknownFuchsia => X8664_UNKNOWN_FUCHSIA,
Target::X8664UnknownFreebsd => X8664_UNKNOWN_FREEBSD,
Target::X8664UnknownLinuxGnu => X8664_UNKNOWN_LINUX_GNU,
Target::Unknown => "",
};
write!(f, "{}", target_str)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
struct MockTargetRetriever {
mock_data: Vec<String>,
}
impl MockTargetRetriever {
fn new(mock_data: Vec<String>) -> Self {
Self { mock_data }
}
}
impl TargetRetriever for MockTargetRetriever {
fn get_installed_targets(&self) -> Vec<String> {
self.mock_data.clone()
}
}
#[rstest]
#[case(vec!["".to_string()], false)] // empty string
#[case(vec!["x86_64-pc-windows-msvc".to_string()], false)] // not supported target
#[case(vec!["x86_64-pc-windows-msvc".to_string(), "".to_string()], false)] // not supported target and empty string
#[case(vec!["x86_64-unknown-linux-gnu".to_string()], true)] // one supported target
#[case(vec!["aarch64-apple-darwin".to_string(), "x86_64-unknown-linux-gnu".to_string()], true)] // one unsupported target and one supported
fn test_is_target_supported(#[case] installed_targets: Vec<String>, #[case] expected: bool) {
let mock_retriever = MockTargetRetriever::new(installed_targets);
let sanitizer = Sanitizer::Memory;
assert_eq!(sanitizer.is_target_supported(&mock_retriever), expected);
}
#[test]
fn test_consistency_of_fmt_and_from_str_strings() {
let variants = vec![
Target::Aarch64AppleDarwin,
Target::Aarch64LinuxAndroid,
Target::Aarch64UnknownFuchsia,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664LinuxAndroid,
Target::X8664UnknownFuchsia,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
];
for variant in variants {
let variant_str = format!("{}", variant);
let parsed_variant = Target::from_str(&variant_str);
assert_eq!(Some(variant), parsed_variant);
}
}
}