
# Advanced LSTM Implementation with Burn

A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn, using combined weight matrices for the input and hidden states, based on the PyTorch implementation.
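
As background, "combined weight matrices" means the four gate projections are stacked so each time step needs only two matrix multiplications. This is the standard stacked-gate formulation (with input size $D$ and hidden size $H$), not code taken from this example's source:

$$
\begin{aligned}
z_t &= W_{ih}\,x_t + b_{ih} + W_{hh}\,h_{t-1} + b_{hh}, \qquad W_{ih} \in \mathbb{R}^{4H \times D},\; W_{hh} \in \mathbb{R}^{4H \times H} \\
[\,i_t;\ f_t;\ g_t;\ o_t\,] &= z_t \text{ split into four chunks of size } H \\
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
h_t &= \sigma(o_t) \odot \tanh(c_t)
\end{aligned}
$$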

`LstmNetwork` is the top-level module, with support for bidirectionality and regularization. The LSTM variants differ in their `bidirectional` and `num_layers` settings (a configuration sketch follows the list):

- LSTM: `num_layers = 1` and `bidirectional = false`
- Stacked LSTM: `num_layers > 1` and `bidirectional = false`
- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`
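
A minimal sketch of how these variants map to configuration values. The struct and field names here are hypothetical; the example's actual config type lives under `src/` and may differ:

```rust
/// Hypothetical configuration mirroring the settings above; the example's
/// real config type (and its field names) may differ.
struct LstmNetworkConfig {
    num_layers: usize,
    bidirectional: bool,
    dropout: f64, // regularization between stacked layers
}

fn main() {
    // The four variants from the list above:
    let lstm = LstmNetworkConfig { num_layers: 1, bidirectional: false, dropout: 0.0 };
    let stacked = LstmNetworkConfig { num_layers: 2, bidirectional: false, dropout: 0.2 };
    let bi = LstmNetworkConfig { num_layers: 1, bidirectional: true, dropout: 0.0 };
    let bi_stacked = LstmNetworkConfig { num_layers: 2, bidirectional: true, dropout: 0.2 };

    for (name, cfg) in [
        ("LSTM", &lstm),
        ("Stacked LSTM", &stacked),
        ("Bidirectional LSTM", &bi),
        ("Bidirectional Stacked LSTM", &bi_stacked),
    ] {
        println!(
            "{name}: num_layers={}, bidirectional={}, dropout={}",
            cfg.num_layers, cfg.bidirectional, cfg.dropout
        );
    }
}
```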

This implementation complements Burn's official LSTM; users can choose either one depending on their project's specific needs.

## Usage

### Training

```sh
# Cuda backend
cargo run --example lstm-train --release --features cuda

# Wgpu backend
cargo run --example lstm-train --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the CUDA version
cargo run --example lstm-train --release --features tch-gpu

# Tch CPU backend
cargo run --example lstm-train --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example lstm-train --release --features ndarray
cargo run --example lstm-train --release --features ndarray-blas-openblas
cargo run --example lstm-train --release --features ndarray-blas-netlib
```

### Inference

```sh
# Cuda backend
cargo run --example lstm-infer --release --features cuda
```