mirror of https://github.com/tracel-ai/burn.git
Wasserstein Generative Adversarial Network
A Burn implementation of an exemplar WGAN model that generates MNIST digits, inspired by the PyTorch implementation. Note that better performance may be achieved by adopting convolutional layers, as some other models do.
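The README doesn't spell out what a WGAN optimizes, so here is a minimal sketch of the objectives, assuming the original weight-clipping WGAN formulation (which the PyTorch reference follows). Plain `f32` slices stand in for Burn tensors. `critic_loss`, `generator_loss`, `clip_weights` and `is_generator_step` are hypothetical helpers, not functions from this example. The values `clip_value = 0.01` and `n_critic = 5` are the usual defaults of that formulation and are assumed here, not read from this example's config.

```rust
// Sketch of WGAN objectives on plain slices (hypothetical helpers, not Burn APIs).
// Assumes the critic outputs one unbounded real-valued score per sample.

/// Mean of a non-empty slice of scores.
fn mean(xs: &[f32]) -> f32 {
    xs.iter().sum::<f32>() / xs.len() as f32
}

/// Critic loss: E[D(fake)] - E[D(real)].
/// Minimizing it widens the score gap between real and generated samples,
/// which estimates the Wasserstein distance (up to a constant).
fn critic_loss(real_scores: &[f32], fake_scores: &[f32]) -> f32 {
    mean(fake_scores) - mean(real_scores)
}

/// Generator loss: -E[D(G(z))]; minimizing it raises the critic's score of fakes.
fn generator_loss(fake_scores: &[f32]) -> f32 {
    -mean(fake_scores)
}

/// Weight clipping: keep every critic weight in [-clip_value, clip_value]
/// after each critic update (crude way to keep the critic Lipschitz).
fn clip_weights(weights: &mut [f32], clip_value: f32) {
    for w in weights.iter_mut() {
        *w = w.clamp(-clip_value, clip_value);
    }
}

/// The critic trains every step; the generator only once every `n_critic` steps.
fn is_generator_step(step: usize, n_critic: usize) -> bool {
    step % n_critic == 0
}

fn main() {
    let clip_value = 0.01_f32; // assumed default from the original WGAN setup
    let n_critic = 5; // assumed: critic updates per generator update

    // Toy critic scores for a batch of real and generated images.
    let real = [0.9_f32, 1.1, 1.0];
    let fake = [-0.2_f32, 0.1, -0.5];
    println!("critic loss    = {:.3}", critic_loss(&real, &fake));
    println!("generator loss = {:.3}", generator_loss(&fake));

    let mut weights = vec![0.5_f32, -0.003, -0.2, 0.007];
    clip_weights(&mut weights, clip_value);
    println!("clipped weights = {weights:?}");

    let gen_steps: Vec<usize> = (0..12).filter(|&s| is_generator_step(s, n_critic)).collect();
    println!("generator updated at steps {gen_steps:?}");
}
```

Weight clipping is the simplest way to enforce the critic's Lipschitz constraint; later variants such as WGAN-GP replace it with a gradient penalty, which this sketch does not cover.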
Usage
Training
# CUDA backend
cargo run --example wgan-mnist --release --features cuda
# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu
# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the CUDA version
cargo run --example wgan-mnist --release --features tch-gpu
# Tch CPU backend
cargo run --example wgan-mnist --release --features tch-cpu
# NdArray backend (CPU)
cargo run --example wgan-mnist --release --features ndarray # f32 - single thread
cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas
cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib
Generating
To generate a sample of images, you can use the wgan-generate example. The same feature flags are used to select a backend.
cargo run --example wgan-generate --release --features cuda
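This README doesn't describe how wgan-generate turns model output into an image. As a hedged sketch, assuming the generator ends in a tanh layer (as the PyTorch reference does) so that pixel values lie in [-1, 1], the following shows how a batch of 28x28 samples could be rescaled to 8-bit grayscale and tiled into one grid. `to_u8` and `tile_images` are hypothetical helpers. Writing the buffer to a PNG would need an image crate and is left out.

```rust
// Hypothetical helpers for turning generated samples into one grayscale grid.
// Assumes each image is a flattened, row-major square with values in [-1, 1].

/// Map a value in [-1, 1] to 0..=255 (values outside the range are clamped).
fn to_u8(v: f32) -> u8 {
    (((v.clamp(-1.0, 1.0) + 1.0) * 0.5) * 255.0).round() as u8
}

/// Tile square images of side `side` into a grid with `cols` columns (`cols` > 0).
/// Returns (width, height, pixels); unused grid cells stay black (0).
fn tile_images(images: &[Vec<f32>], side: usize, cols: usize) -> (usize, usize, Vec<u8>) {
    let rows = (images.len() + cols - 1) / cols; // ceiling division
    let width = cols * side;
    let height = rows * side;
    let mut canvas = vec![0u8; width * height];
    for (idx, img) in images.iter().enumerate() {
        assert_eq!(img.len(), side * side, "image {idx} has the wrong size");
        let (grid_row, grid_col) = (idx / cols, idx % cols);
        for y in 0..side {
            for x in 0..side {
                // Offset of this pixel inside the large canvas.
                let dst = (grid_row * side + y) * width + grid_col * side + x;
                canvas[dst] = to_u8(img[y * side + x]);
            }
        }
    }
    (width, height, canvas)
}

fn main() {
    const SIDE: usize = 28; // MNIST digits are 28x28
    // Stand-in for generator output: 5 constant images from -1.0 to 1.0.
    let batch: Vec<Vec<f32>> = (0..5)
        .map(|i| vec![-1.0 + 0.5 * i as f32; SIDE * SIDE])
        .collect();
    let (w, h, pixels) = tile_images(&batch, SIDE, 3);
    println!("{w}x{h} grid, {} pixels", pixels.len());
}
```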