
Wasserstein Generative Adversarial Network

A Burn implementation of an exemplar WGAN model that generates MNIST digits, inspired by the PyTorch implementation. Note that better performance may be achieved by using convolutional layers, as some other models do.
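In the original WGAN formulation, the discriminator is replaced by a critic that outputs an unbounded score rather than a probability. The critic is trained to widen the gap between its mean score on real and on generated samples, and its weights are clipped to a small range to keep it approximately Lipschitz. The std-only Rust sketch below illustrates those two losses and the clipping step on plain f32 slices. It is not this example's code: the function names are hypothetical, and the clip value 0.01 is the default from the WGAN paper, assumed here rather than taken from this example's configuration.

```rust
/// Mean of a slice of critic scores.
fn mean(xs: &[f32]) -> f32 {
    xs.iter().sum::<f32>() / xs.len() as f32
}

/// Critic loss (hypothetical helper): the critic maximizes E[D(x)] - E[D(G(z))],
/// so the quantity to minimize is the negation, E[D(G(z))] - E[D(x)].
fn critic_loss(real_scores: &[f32], fake_scores: &[f32]) -> f32 {
    mean(fake_scores) - mean(real_scores)
}

/// Generator loss (hypothetical helper): the generator maximizes E[D(G(z))],
/// i.e. it minimizes -E[D(G(z))].
fn generator_loss(fake_scores: &[f32]) -> f32 {
    -mean(fake_scores)
}

/// Weight clipping: clamp every critic weight into [-clip_value, clip_value]
/// after each critic update (0.01 in the WGAN paper; assumed here).
fn clip_weights(weights: &mut [f32], clip_value: f32) {
    for w in weights.iter_mut() {
        *w = w.clamp(-clip_value, clip_value);
    }
}

fn main() {
    // Toy critic scores standing in for the critic network's outputs.
    let real = [1.0_f32, 2.0, 3.0];
    let fake = [0.5_f32, 1.0, 1.5];
    println!("critic loss = {}", critic_loss(&real, &fake)); // -1
    println!("generator loss = {}", generator_loss(&fake)); // -1

    // Toy critic weights, clipped in place.
    let mut weights = [0.5_f32, -0.003, -0.2];
    clip_weights(&mut weights, 0.01);
    println!("clipped weights = {:?}", weights); // [0.01, -0.003, -0.01]
}
```

Because the critic's scores are unbounded, the losses have no log or sigmoid, which is the main practical difference from a standard GAN loss.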

Usage

Training

# CUDA backend
cargo run --example wgan-mnist --release --features cuda

# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the CUDA version
cargo run --example wgan-mnist --release --features tch-gpu

# Tch CPU backend
cargo run --example wgan-mnist --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example wgan-mnist --release --features ndarray                # f32 - single thread
cargo run --example wgan-mnist --release --features ndarray-blas-openblas  # f32 - blas with openblas
cargo run --example wgan-mnist --release --features ndarray-blas-netlib    # f32 - blas with netlib

Generating

To generate a sample of images, use the wgan-generate example. It accepts the same feature flags to select a backend.

cargo run --example wgan-generate --release --features cuda
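Generated samples are commonly saved as a single grid image so that many digits can be inspected at once. As a hypothetical illustration (the helper name tile_images and the assumption that pixel values lie in [0, 1] are not taken from this example), the std-only Rust sketch below tiles several row-major grayscale images into one 8-bit buffer. If a generator ends in tanh, its outputs in [-1, 1] would first need rescaling with (v + 1) / 2.

```rust
/// Hypothetical helper: tile grayscale images of size `h x w` (row-major,
/// values assumed in [0, 1]) into one grid with `cols` images per row.
/// `cols` must be non-zero. Returns (pixels, grid_width, grid_height).
fn tile_images(images: &[Vec<f32>], h: usize, w: usize, cols: usize) -> (Vec<u8>, usize, usize) {
    // Round the row count up so a partial last row still fits.
    let rows = (images.len() + cols - 1) / cols;
    let (grid_w, grid_h) = (cols * w, rows * h);
    // Unused slots in the last row stay black (0).
    let mut grid = vec![0u8; grid_w * grid_h];

    for (i, img) in images.iter().enumerate() {
        // Position of this image within the grid.
        let (r, c) = (i / cols, i % cols);
        for y in 0..h {
            for x in 0..w {
                // Clamp to [0, 1], then convert to an 8-bit intensity.
                let v = img[y * w + x].clamp(0.0, 1.0);
                grid[(r * h + y) * grid_w + c * w + x] = (v * 255.0).round() as u8;
            }
        }
    }
    (grid, grid_w, grid_h)
}

fn main() {
    // Three 2x2 images (white, black, mid-gray), tiled two per row.
    let images = vec![vec![1.0_f32; 4], vec![0.0; 4], vec![0.5; 4]];
    let (grid, grid_w, grid_h) = tile_images(&images, 2, 2, 2);
    println!("grid is {}x{} pixels", grid_w, grid_h); // grid is 4x4 pixels
    println!("{:?}", grid);
}
```

The resulting buffer could then be written to disk with any image library that accepts raw 8-bit grayscale pixels.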