Doc fixes (#418)

Author: Dilshod Tadjibaev
Date: 2023-06-21 11:32:50 -05:00 (committed by GitHub)
Parent: 73a88d8209
Commit: fce45f51be
22 changed files with 58 additions and 47 deletions

@@ -3,7 +3,7 @@ use crate as burn;
 use super::LRScheduler;
 use crate::{config::Config, LearningRate};
 
-/// Configuration to create a [noam](NoamScheduler) learning rate scheduler.
+/// Configuration to create a [noam](NoamLRScheduler) learning rate scheduler.
 #[derive(Config)]
 pub struct NoamLRSchedulerConfig {
     /// The initial learning rate.
@@ -26,7 +26,7 @@ pub struct NoamLRScheduler {
 }
 
 impl NoamLRSchedulerConfig {
-    /// Initialize a new [noam](NoamScheduler) learning rate scheduler.
+    /// Initialize a new [noam](NoamLRScheduler) learning rate scheduler.
     pub fn init(&self) -> NoamLRScheduler {
         NoamLRScheduler {
             warmup_steps: self.warmup_steps as f64,
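
A usage sketch, not part of the diff: the `new` constructor, the `with_warmup_steps` setter generated by `#[derive(Config)]`, and the module path are assumptions based only on this hunk.

use burn::lr_scheduler::{noam::NoamLRSchedulerConfig, LRScheduler}; // path assumed

fn make_scheduler() {
    // Assumed: the initial learning rate is the only required config field.
    let mut scheduler = NoamLRSchedulerConfig::new(1.0e-2)
        .with_warmup_steps(4000)
        .init();
    // LRScheduler::step is assumed to yield the next learning rate.
    let _lr = scheduler.step();
}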

@@ -64,7 +64,7 @@ macro_rules! module {
 ///
 /// Modules should be created using the [derive](burn_derive::Module) attribute.
 /// This will make your module trainable, savable and loadable via
-/// [state](Module::state) and [load](Module::load).
+/// `state` and `load`.
 ///
 /// # Example
 ///
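
For reference, a module defined with that derive attribute, following the same pattern as the Conv1d/Conv2d structs later in this commit; the `Linear` submodule and import paths are assumptions.

use burn::module::Module;
use burn::nn::Linear;
use burn::tensor::backend::Backend;

// Deriving Module is what makes the struct trainable, savable and loadable,
// as the doc comment above describes.
#[derive(Module, Debug)]
pub struct MyBlock<B: Backend> {
    linear: Linear<B>,
}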

@@ -59,7 +59,7 @@ pub enum Conv1dPaddingConfig {
 /// - weight: Tensor of shape [channels_out, channels_in, kernel_size] initialized from a uniform
 /// distribution `U(-k, k)` where `k = sqrt(1 / (channels_in * kernel_size))`
 ///
-/// - bias: Tensor of shape [channels_out], initialized from a uniform distribution `U(-k, k)`
+/// - bias: Tensor of shape `[channels_out]`, initialized from a uniform distribution `U(-k, k)`
 /// where `k = sqrt(1 / (channels_in * kernel_size))`
 #[derive(Module, Debug)]
 pub struct Conv1d<B: Backend> {

@@ -54,10 +54,10 @@ pub enum Conv2dPaddingConfig {
 ///
 /// # Params
 ///
-/// - weight: Tensor of shape [channels_out, channels_in, kernel_size_1, kernel_size_2] initialized from a uniform
+/// - weight: Tensor of shape `[channels_out, channels_in, kernel_size_1, kernel_size_2]` initialized from a uniform
 /// distribution `U(-k, k)` where `k = sqrt(1 / (channels_in * kernel_size_1 * kernel_size_2))`
 ///
-/// - bias: Tensor of shape [channels_out], initialized from a uniform distribution `U(-k, k)`
+/// - bias: Tensor of shape `[channels_out]`, initialized from a uniform distribution `U(-k, k)`
 /// where `k = sqrt(1 / (channels_in * kernel_size_1 * kernel_size_2))`
 #[derive(Module, Debug)]
 pub struct Conv2d<B: Backend> {
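
The `k` in both `U(-k, k)` bounds above, restated with explicit precedence as a standalone helper (not crate code): the divisor is the full product of `channels_in` and the kernel dimensions.

// Uniform-init bound documented for Conv1d/Conv2d:
// k = sqrt(1 / (channels_in * product of kernel dims))
fn conv_init_bound(channels_in: usize, kernel_elems: usize) -> f64 {
    (1.0 / (channels_in * kernel_elems) as f64).sqrt()
}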

@@ -50,8 +50,8 @@ impl Initializer {
     /// # Params
     ///
     /// - shape: Shape of the initialized tensor.
-    /// - fan_in: Option<usize>, the fan in to use in initialization formula, if needed
-    /// - fan_out: Option<usize>, the fan out to use in initialization formula, if needed
+    /// - fan_in: `Option<usize>`, the fan in to use in the initialization formula, if needed
+    /// - fan_out: `Option<usize>`, the fan out to use in the initialization formula, if needed
     pub fn init_with<B: Backend, const D: usize, S: Into<Shape<D>>>(
         &self,
         shape: S,
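
Calling it might look like this sketch; the `Tensor<B, 2>` return type is an assumption, since the hunk cuts off before the return type, and the paths are assumed.

use burn::nn::Initializer; // path assumed
use burn::tensor::{backend::Backend, Tensor};

// fan_in/fan_out are only consulted by initializers whose formula needs
// them; pass None otherwise.
fn init_weight<B: Backend>(init: &Initializer, d_in: usize, d_out: usize) -> Tensor<B, 2> {
    init.init_with([d_in, d_out], Some(d_in), Some(d_out))
}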

@@ -22,8 +22,8 @@ impl<B: Backend> CrossEntropyLoss<B> {
     ///
     /// # Shapes
     ///
-    /// - logits: [batch_size, num_targets]
-    /// - targets: [batch_size]
+    /// - logits: `[batch_size, num_targets]`
+    /// - targets: `[batch_size]`
     pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
         let [batch_size] = targets.dims();
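
The shape contract above as a typed wrapper; the signature comes straight from this hunk, only the `use` paths are assumptions.

use burn::nn::loss::CrossEntropyLoss; // path assumed
use burn::tensor::{backend::Backend, Int, Tensor};

// logits: [batch_size, num_targets], targets: [batch_size] (class indices)
fn batch_loss<B: Backend>(
    loss: &CrossEntropyLoss<B>,
    logits: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
    loss.forward(logits, targets)
}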

@@ -77,7 +77,7 @@ impl LstmConfig {
         }
     }
 
-    /// Initialize a new [lstm](lstm) module with a [record](LstmRecord).
+    /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord).
     pub fn init_with<B: Backend>(&self, record: LstmRecord<B>) -> Lstm<B> {
         let linear_config = LinearConfig {
             d_input: self.d_input,
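
Restoring a module from a saved record, per the `init_with` signature above; the import path is an assumption.

use burn::nn::lstm::{Lstm, LstmConfig, LstmRecord}; // path assumed
use burn::tensor::backend::Backend;

// init_with consumes a record so the module starts from its recorded
// weights instead of a fresh initialization.
fn restore_lstm<B: Backend>(config: &LstmConfig, record: LstmRecord<B>) -> Lstm<B> {
    config.init_with(record)
}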

@@ -24,7 +24,7 @@ pub struct SgdConfig {
 
 /// Optimizer that implements stochastic gradient descent with momentum.
 ///
-/// Momentum is optional and can be [configured](SgdConfig::momentum).
+/// The optimizer can be configured with [SgdConfig](SgdConfig).
 pub struct Sgd<B: Backend> {
     momentum: Option<Momentum<B>>,
     weight_decay: Option<WeightDecay<B>>,

@@ -28,6 +28,6 @@ where
     /// Change the device of the state.
     ///
     /// This function will be called as needed to keep the state on the same device as the
-    /// gradient and the tensor when the [step](SimpleModuleOptimizer::step) function is called.
+    /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
     fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
 }
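
For an optimizer whose per-parameter state is a single tensor, the implementation is plausibly a plain tensor move; a sketch assuming `State<D> = Tensor<B, D>`, which this hunk does not confirm.

use burn::tensor::{backend::Backend, Tensor};

// Moving a tensor-valued optimizer state onto the gradient's device.
fn state_to_device<B: Backend, const D: usize>(
    state: Tensor<B, D>,
    device: &B::Device,
) -> Tensor<B, D> {
    state.to_device(device)
}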

@@ -3,11 +3,11 @@ pub use burn_derive::Record;
 use super::PrecisionSettings;
 use serde::{de::DeserializeOwned, Serialize};
 
-/// Trait to define a family of types which can be recorded using any [settings](RecordSettings).
+/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
 pub trait Record: Send + Sync {
     type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
 
-    /// Convert the current record into the corresponding item that follows the given [settings](RecordSettings).
+    /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
     /// Convert the given item into a record.
     fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
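
The two methods compose into a round trip for any fixed settings type; `FullPrecisionSettings` as one such implementation, and its path, are assumptions.

use burn::record::{FullPrecisionSettings, Record}; // paths assumed

// into_item/from_item are expected to round-trip for a fixed
// PrecisionSettings choice.
fn round_trip<R: Record>(record: R) -> R {
    let item = record.into_item::<FullPrecisionSettings>();
    R::from_item::<FullPrecisionSettings>(item)
}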

@@ -21,7 +21,7 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
     /// Arguments used to load recorded objects.
     type LoadArgs: Clone;
 
-    /// Record using the given [settings](RecordSettings).
+    /// Record an item with the given arguments.
     fn record<R: Record>(
         &self,
         record: R,

@@ -43,7 +43,7 @@ where
 
     /// Create from a JSON rows file (one JSON object per line).
     ///
-    /// Supported field types: https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html
+    /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)
     pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
         let file = File::open(path)?;
         let reader = BufReader::new(file);
@@ -65,7 +65,7 @@
     ///
     /// The supported field types are: String, integer, float, and bool.
    ///
-    /// See: https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde
+    /// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
     pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
         let file = File::open(path)?;
         let reader = BufReader::new(file);
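
The owning type is not visible in these hunks; assuming it is burn-dataset's `InMemDataset`, loading might look like the sketch below. The struct and file names are placeholders.

use burn::data::dataset::InMemDataset; // path and type assumed
use serde::Deserialize;

#[derive(Deserialize, Clone, Debug)]
struct Review {
    text: String,
    score: f32,
}

fn load_datasets() -> std::io::Result<()> {
    // One JSON object per line, with fields matching the struct above.
    let _train = InMemDataset::<Review>::from_json_rows("train.jsonl")?;
    // CSV with String, integer, float, and bool fields.
    let _test = InMemDataset::<Review>::from_csv("test.csv")?;
    Ok(())
}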

@@ -68,13 +68,13 @@ impl From<&'static str> for SqliteDatasetError {
 /// can be in any order.
 ///
 /// For the supported field types, refer to:
-/// - Serialization field types: https://docs.rs/serde_rusqlite/latest/serde_rusqlite
-/// - SQLite data types: https://www.sqlite.org/datatype3.html
+/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
+/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
 ///
 /// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
 /// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
 /// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
-/// MessagePack (https://msgpack.org/).
+/// [MessagePack](https://msgpack.org/).
 ///
 /// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
 /// method to read the data from the table.
@@ -490,7 +490,7 @@ where
 
     /// Serializes and writes an item to the database. The item is written to the table for the
     /// specified split. If the table does not exist, it is created. If the table exists, the item
-    /// is appended to the table. The serialization is done using the MessagePack (https://msgpack.org/)
+    /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) format.
     ///
     /// # Arguments
     ///
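
Reading such a table back might look like the sketch below; the `from_db_file` constructor, its arguments, and the paths are assumptions, since no constructor appears in these hunks.

use burn::data::dataset::{Dataset, SqliteDataset}; // paths assumed
use serde::Deserialize;

#[derive(Deserialize, Clone, Debug)]
struct Item {
    label: i64,
    text: String,
}

fn read_first_item() {
    // Hypothetical constructor: open a database file and pick the "train" split.
    let dataset: SqliteDataset<Item> =
        SqliteDataset::from_db_file("data.db", "train").expect("valid db");
    let _first = dataset.get(0);
}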

@@ -81,7 +81,7 @@ impl HuggingfaceDatasetLoader {
 
     /// Specify a huggingface token to download datasets behind authentication.
     ///
-    /// You can get a token from https://huggingface.co/settings/tokens
+    /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
     pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
         self.huggingface_token = Some(huggingface_token.to_string());
         self
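
Builder-style usage; only `with_huggingface_token` appears in this hunk, so the `new` constructor, its argument, and the import path are assumptions, and the token is a placeholder.

use burn::data::dataset::source::huggingface::HuggingfaceDatasetLoader; // path assumed

fn gated_loader() -> HuggingfaceDatasetLoader {
    HuggingfaceDatasetLoader::new("some-gated-dataset") // assumed constructor
        .with_huggingface_token("hf_your_token_here")   // placeholder token
}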

@@ -81,9 +81,9 @@ pub trait ModuleOps<B: Backend> {
     ///
     /// # Shapes
     ///
-    /// x: [batch_size, channels_in, height, width],
-    /// weight: [channels_out, channels_in, kernel_size_1, kernel_size_2],
-    /// bias: [channels_out],
+    /// x: `[batch_size, channels_in, height, width]`,
+    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
+    /// bias: `[channels_out]`,
     fn conv2d(
         x: B::TensorPrimitive<4>,
         weight: B::TensorPrimitive<4>,
@@ -94,9 +94,9 @@ pub trait ModuleOps<B: Backend> {
     ///
     /// # Shapes
     ///
-    /// x: [batch_size, channels_in, height, width],
-    /// weight: [channels_in, channels_out, kernel_size_1, kernel_size_2],
-    /// bias: [channels_out],
+    /// x: `[batch_size, channels_in, height, width]`,
+    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
+    /// bias: `[channels_out]`,
     fn conv_transpose2d(
         x: B::TensorPrimitive<4>,
         weight: B::TensorPrimitive<4>,
@@ -118,9 +118,9 @@ pub trait ModuleOps<B: Backend> {
     ///
     /// # Shapes
     ///
-    /// x: [batch_size, channels_in, length],
-    /// weight: [channels_out, channels_in, kernel_size],
-    /// bias: [channels_out],
+    /// x: `[batch_size, channels_in, length]`,
+    /// weight: `[channels_out, channels_in, kernel_size]`,
+    /// bias: `[channels_out]`,
     fn conv1d(
         x: B::TensorPrimitive<3>,
         weight: B::TensorPrimitive<3>,
@@ -133,9 +133,9 @@ pub trait ModuleOps<B: Backend> {
     ///
     /// # Shapes
     ///
-    /// x: [batch_size, channels_in, length],
-    /// weight: [channels_in, channels_out, length],
-    /// bias: [channels_out],
+    /// x: `[batch_size, channels_in, length]`,
+    /// weight: `[channels_in, channels_out, kernel_size]`,
+    /// bias: `[channels_out]`,
     fn conv_transpose1d(
         x: B::TensorPrimitive<3>,
         weight: B::TensorPrimitive<3>,
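
These shape conventions pair with the standard output-size arithmetic; a standalone helper, not crate code.

// Output length along one spatial dimension for the conv ops above,
// with dilation included; integer division performs the floor.
fn conv_out_len(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}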

@@ -7,7 +7,7 @@ use burn_core::tensor::backend::ADBackend;
 
 /// Learner struct encapsulating all components necessary to train a Neural Network model.
 ///
-/// To create a learner, use the [builder](crate::train::LearnerBuilder) struct.
+/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
 pub struct Learner<B, M, O, LR, TO, VO>
 where
     B: ADBackend,

@@ -144,8 +144,8 @@ where
         self
     }
 
-    /// Register a checkpointer that will save the [optimizer](crate::optim::Optimizer) and the
-    /// [model](crate::module::Module) [states](crate::module::State).
+    /// Register a checkpointer that will save the [optimizer](Optimizer) and the
+    /// [model](ADModule).
     ///
     /// The number of checkpoints to keep should be set to a minimum of two to be safe, since
     /// they are saved and deleted asynchronously and a crash during training might make a

@@ -44,7 +44,7 @@ pub trait Metric: Send + Sync {
 /// Adaptors are used to transform types so that they can be used by metrics.
 ///
 /// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
-/// registed with the [leaner buidler](burn::train::LearnerBuilder).
+/// registered with the [learner builder](crate::learner::LearnerBuilder).
 pub trait Adaptor<T> {
     /// Adapt the type to be passed to a [metric](Metric).
     fn adapt(&self) -> T;
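
An output type implements `Adaptor` once per metric input it supports; a sketch assuming a `LossInput` metric input with a `new` constructor, neither of which appears in this hunk.

use burn::tensor::{backend::Backend, Tensor};
use burn::train::metric::{Adaptor, LossInput}; // LossInput assumed

struct MyOutput<B: Backend> {
    loss: Tensor<B, 1>,
}

// Lets a loss metric registered with the learner builder pull its input
// out of the model's output type.
impl<B: Backend> Adaptor<LossInput<B>> for MyOutput<B> {
    fn adapt(&self) -> LossInput<B> {
        LossInput::new(self.loss.clone())
    }
}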

@@ -1,6 +1,6 @@
 use super::{MetricEntry, Numeric};
 
-/// Usefull utility to implement numeric [metrics](crate::train::metric::Metric).
+/// Useful utility to implement numeric metrics.
 ///
 /// # Notes
 ///

@@ -3,24 +3,29 @@
 /// Options are:
 /// - [Vulkan](Vulkan)
 /// - [Metal](Metal)
-/// - [OpenGL](OpenGL)
+/// - [OpenGL](OpenGl)
 /// - [DirectX 11](Dx11)
 /// - [DirectX 12](Dx12)
-/// - [WebGPU](WebGPU)
+/// - [WebGpu](WebGpu)
 pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
     fn backend() -> wgpu::Backend;
 }
 
 #[derive(Default, Debug, Clone)]
 pub struct Vulkan;
 
 #[derive(Default, Debug, Clone)]
 pub struct Metal;
 
 #[derive(Default, Debug, Clone)]
 pub struct OpenGl;
 
 #[derive(Default, Debug, Clone)]
 pub struct Dx11;
 
 #[derive(Default, Debug, Clone)]
 pub struct Dx12;
 
 #[derive(Default, Debug, Clone)]
 pub struct WebGpu;
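
Each marker type maps the trait to one wgpu backend; the actual impls are not in this hunk, but they plausibly look like the sketch below, with the other marker types mapping to their corresponding wgpu::Backend variants in the same way.

impl GraphicsApi for Vulkan {
    // Select the Vulkan backend for this marker type.
    fn backend() -> wgpu::Backend {
        wgpu::Backend::Vulkan
    }
}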

@@ -10,7 +10,7 @@ use burn::tensor::Tensor;
 use wasm_bindgen::prelude::*;
 
 /// Mnist structure that corresponds to JavaScript class.
-/// See: https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html
+/// See: [exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html)
 #[wasm_bindgen]
 pub struct Mnist {
     model: Model<Backend>,
@@ -35,8 +35,8 @@ impl Mnist {
     /// * `input` - An f32 slice of the input 28x28 image
     ///
     /// See bindgen support types for passing and returning arrays:
-    /// * https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html
-    /// * https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html
+    /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html)
+    /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html)
     ///
     pub fn inference(&self, input: &[f32]) -> Result<Box<[f32]>, String> {
         // Reshape from the 1D array to 3d tensor [batch, height, width]

@@ -52,12 +52,18 @@ build_and_test_all_features() {
 
     echo "Build with all defaults"
     cargo build --all-features
 
-    echo "Test with defaults"
+    echo "Test with all features"
     cargo test --all-features
 
+    echo "Check documentation with all features"
+    cargo doc --all-features
+
     cd .. || exit
 }
 
+# Set RUSTDOCFLAGS to treat warnings as errors for the documentation build
+export RUSTDOCFLAGS="-D warnings"
+
 # Save the script start time
 start_time=$(date +%s)
@@ -65,11 +71,11 @@ start_time=$(date +%s)
 rustup target add wasm32-unknown-unknown
 rustup target add thumbv7m-none-eabi
 
 # TODO decide if we should "cargo clean" here.
 cargo build --workspace
 cargo test --workspace
 cargo fmt --check --all
 cargo clippy -- -D warnings
+cargo doc --workspace
 # no_std tests
 build_and_test_no_std "burn"