Many superficial fixes to the contributor book (#1644)

* wip

* many fixes in the contributor book

* undo candle modif

* oops candle changes shouldnt have been there

* typo

* fix commands
Louis Fortier-Dubois 2024-04-16 17:17:11 -04:00 committed by GitHub
parent 0ee2021567
commit e4b80bad5e
23 changed files with 289 additions and 612 deletions

View File

@@ -1,230 +0,0 @@
<!--
TODO: Add the following sections:
# Tenets
# Design Philosophy
-->
# Architecture
This file documents most major architectural decisions with the reasoning behind them.
**Sections**
- [Architecture](#architecture)
- [Module](#module)
- [Optimization](#optimization)
- [Constraints](#constraints)
- [Solution](#solution)
- [Serialization](#serialization)
- [Constraints](#constraints-1)
- [Solution](#solution-1)
- [Pros](#pros)
- [Cons](#cons)
- [Compatibility](#compatibility)
- [Tensor](#tensor)
- [Backend](#backend)
- [Autodiff](#autodiff)
## Module
Modules are a way of creating neural network structures that can be easily optimized, saved, and loaded with little to no boilerplate.
Unlike other frameworks, a module does not force the declaration of the forward pass, leaving it up to the implementer to decide how it should be defined.
Additionally, most modules are created using a (de)serializable configuration, which defines the structure of the module and its hyper-parameters.
Parameters and hyper-parameters are not serialized into the same file and both are normally necessary to load a module for inference.
### Optimization
Optimization is normally done with gradient descent (or ascent for reinforcement learning), and it is important to provide an easy API for optimizing modules.
#### Constraints
1. **Users should be able to control what is optimized.**
Modules can contain anything for maximum flexibility, but not everything needs to be optimized.
2. **Optimizers should have a serializable state that is updated during training.**
Many optimizers keep track of previous gradients to implement some form of momentum.
However, the state can be anything, not just tensors, allowing for easy implementation of any kind of optimizer.
3. **The learning rate can be updated during training.**
Learning rate schedulers are often used during training and should be considered a key aspect.
#### Solution
The solution to this problem comprises multiple parts.
Firstly, the `Optimizer` trait is quite similar to the `Module` trait in terms of saving and loading the state.
Please refer to the [serialization](#serialization) section for more details.
Secondly, two traits were created.
The `Optimizer` trait is general and relatively unopinionated, with a simple `step` method that takes a learning rate, a module, and the gradients.
The other trait, `SimpleOptimizer`, aims to provide an easier API for implementing new optimizers.
The goal is to allow implementations to avoid handling missing gradients, loading and exporting records, navigating the module parameter structure, handling tracked and untracked tensors, and other such tasks.
Thirdly, each tensor that will be optimized needs to be wrapped into a `Param` struct, which gives it an ID used for (de)serialization and for associating the optimizer state with each parameter.
The `Module` trait has two ways to navigate over parameters.
The first one is the `map` function, which returns `Self` and makes it easy to implement any transformation and mutate all parameters.
The second one is the `visit` function, which has a similar signature but does not mutate the parameter tensors.
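To make the two traversal styles concrete, here is a minimal sketch of the pattern; the trait and type names below are simplified stand-ins, not Burn's actual signatures:

```rust
// Hypothetical, simplified counterparts of `Module::map` and `Module::visit`.
trait ParamMapper {
    /// Receives each parameter tensor and returns its transformed version.
    fn map_param(&mut self, id: u64, values: Vec<f32>) -> Vec<f32>;
}

trait ParamVisitor {
    /// Receives each parameter tensor read-only.
    fn visit_param(&mut self, id: u64, values: &[f32]);
}

trait Module: Sized {
    /// Consumes `self` and rebuilds it with every parameter transformed.
    fn map<M: ParamMapper>(self, mapper: &mut M) -> Self;
    /// Walks every parameter without mutating the module.
    fn visit<V: ParamVisitor>(&self, visitor: &mut V);
}

/// Example visitor: counts the total number of parameter elements.
struct NumParams(usize);

impl ParamVisitor for NumParams {
    fn visit_param(&mut self, _id: u64, values: &[f32]) {
        self.0 += values.len();
    }
}
```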
**SimpleOptimizer**
The `SimpleOptimizer` has two major assumptions:
1. The state of the optimizer is linked to each parameter.
In other words, each parameter has its own optimizer state, decoupled from the other parameters.
2. The state of the optimizer implements `Record`, `Clone`, and has a `'static` lifetime.
The benefits of those assumptions materialize in simplicity with little loss in flexibility.
The state associated type is also generic over the dimension, making it extremely easy to include tensors in the state that share the same dimensionality as their parameter.
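For intuition, here is a hedged sketch of a per-parameter optimizer under those two assumptions, using SGD with momentum on plain `Vec<f32>` buffers (illustrative only; the real `SimpleOptimizer` works on tensors and records):

```rust
/// Per-parameter optimizer state; note that it is `Clone`, matching
/// assumption 2 above.
#[derive(Clone)]
struct MomentumState {
    velocity: Vec<f32>, // same shape as the parameter it belongs to
}

/// A sketch of a "simple" optimizer: SGD with momentum.
struct Sgd {
    momentum: f32,
}

impl Sgd {
    /// One optimization step for one parameter tensor, decoupled from
    /// every other parameter (assumption 1 above).
    fn step(
        &self,
        lr: f32,
        param: Vec<f32>,
        grad: Vec<f32>,
        state: Option<MomentumState>,
    ) -> (Vec<f32>, MomentumState) {
        // First step for this parameter: start with zero velocity.
        let mut velocity = state
            .map(|s| s.velocity)
            .unwrap_or_else(|| vec![0.0; param.len()]);
        let param: Vec<f32> = param
            .iter()
            .zip(grad.iter())
            .zip(velocity.iter_mut())
            .map(|((p, g), v)| {
                *v = self.momentum * *v + g;
                p - lr * *v
            })
            .collect();
        (param, MomentumState { velocity })
    }
}
```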
To wrap a simple optimizer into the more general `Optimizer` trait, the `OptimizerAdaptor` struct is used.
**OptimizerAdaptor**
The `OptimizerAdaptor` is a simple struct composed of a `SimpleOptimizer` and a hashmap with all records associated with each parameter ID.
When performing an optimization step, the adaptor handles the following:
1. Updates each parameter tensor in the given module using the `Module::map` function.
2. Checks if a gradient for the current tensor exists.
3. Makes sure that the gradient, the tensor, and the optimizer state associated with the current parameter are on the same device.
The device can be different if the state is loaded from disk to restart training.
4. Performs the simple optimizer step using the inner tensor since the operations done by the optimizer should not be tracked in the autodiff graph.
5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are marked as required.
Note that a parameter can still be updated by another process, as is the case with running metrics used in batch norm.
These tensors are still wrapped using the `Param` struct so that they are included in the module's state and given a proper parameter ID, but they are not registered in the autodiff graph.
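Continuing the sketch above with the hypothetical `Sgd`, the adaptor's bookkeeping essentially boils down to a map from parameter ID to optimizer state (device synchronization and autodiff registration are elided):

```rust
use std::collections::HashMap;

/// Sketch of the adaptor's bookkeeping: one record per parameter ID.
struct SgdAdaptor {
    optim: Sgd,
    records: HashMap<u64, MomentumState>,
}

impl SgdAdaptor {
    /// Called once per parameter while mapping over the module.
    fn step_param(
        &mut self,
        lr: f32,
        id: u64,
        param: Vec<f32>,
        grad: Option<Vec<f32>>,
    ) -> Vec<f32> {
        // A parameter may have received no gradient this step; in that
        // case it is returned untouched (step 2 above).
        let Some(grad) = grad else { return param };
        let state = self.records.remove(&id);
        let (param, state) = self.optim.step(lr, param, grad, state);
        self.records.insert(id, state);
        param
    }
}
```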
### Serialization
An important aspect of a deep learning framework is the ability to save and load models from disk.
Despite appearing as a simple feature, it involves numerous constraints that require a proper solution.
#### Constraints
1. **Users should be able to declare the precision of the model to be saved, independent of the backend in use.**
The modules should not be duplicated in RAM in another precision to support this.
Conversion should be done lazily during (de)serialization.
2. **Users should be able to add any field to a module, even fields that are not serializable.**
This can include constants, database connections, other module references, or any other information.
Only parameters should be serialized since the structure of the module itself should be encapsulated with module configurations (hyper-parameters).
3. **Users should be able to declare the format in which the module should be saved.**
This can involve saving to a compressed JSON file or directly to bytes in memory for `no-std` environments.
4. **Users should be able to create a module with its saved parameters without having to initialize the module first.**
This will avoid unnecessary module initialization and tensor loading, resulting in a reduced cold start for inference.
In addition to all of these constraints, the solution should be easy to use.
#### Solution
In order to be able to add any field to a module without requiring it to be (de)serializable, we decouple the module type from its state.
We create a new type for each module that only contains the parameters that need to be saved.
To generate that type automatically, the user could either declare which fields are parameters and which are constants, or we could assume that each field implements the `Module` trait.
The second solution was chosen, as it simplifies the code generation and reduces the size of the user API.
This means that the `Module` trait should be implemented by [primitive types](./burn-core/src/module/param/primitive.rs).
The following diagrams highlight the main types and traits used in the solution.
<div align="center">
<h4>Module Serialization Types</h4>
<img src="./assets/ModuleSerialization.png" width="700px"/>
<div align="left">
The way the types interact with each other is pretty straightforward.
First, a module can be converted into a record using `into_record()`.
Note that tensors can be cloned, but it won't actually copy any data; it will create another reference to the same data.
Then, a `Recorder` instance can be used to serialize any record.
The `Recorder` has `PrecisionSettings` as an associated type, so any record will be serialized using the settings provided at the creation of the `Recorder` instance.
Note that tensors implement `Record`, and their item is just a wrapper struct that contains information about the precision in which the tensor should be saved or loaded.
No actual copy of the tensor is made until this point.
The tensor is converted to the `Data` struct and then converted into the specified precision only when `serialize()` or `deserialize()` are called, which makes the whole process lazy.
To recapitulate, the `Module` trait has an associated type that implements `Record`, which only contains the parameters of the model.
The `Record` trait has a generic associated type (GAT) that specifies a family of types that can be (de)serialized given any `PrecisionSettings`.
Records are therefore decoupled from the backend in use, and the saved items can be loaded on any backend with any precision, since the conversion is type-safe and done when `serialize()` and `deserialize()` are called.
All of the types are generated using simple derive macros without any conditional statements or complex syntax, as `Record` and `Module` are implemented for all primitive types.
This makes the code simple and easy to maintain.
In addition, you can extend the current system with your own `Recorder` and `PrecisionSettings` to control how your modules should be saved and loaded.
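To make the lazy conversion concrete, here is a small illustrative sketch; the names are hypothetical and the real `Recorder`/`PrecisionSettings` machinery is more involved:

```rust
/// Hypothetical settings trait: the associated element type decides the
/// on-disk precision of float values.
trait PrecisionSettings {
    type FloatElem: From<f32> + Into<f32>;
}

/// Keeps floats as-is; a real implementation could map to f16 or f64 instead.
struct FullPrecision;

impl PrecisionSettings for FullPrecision {
    type FloatElem = f32;
}

/// In-memory tensor data stays in the backend's native precision...
struct TensorData {
    values: Vec<f32>,
}

impl TensorData {
    /// ...and is converted into the target element type only at the moment
    /// the record item is produced for (de)serialization, which keeps the
    /// whole process lazy.
    fn into_item<S: PrecisionSettings>(self) -> Vec<S::FloatElem> {
        self.values.into_iter().map(|v| v.into()).collect()
    }
}
```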
##### Pros
- All constraints are respected.
- The code is simple and easy to maintain, with very few conditional statements.
It is just recursive data structures, where all the complexity is handled by the framework in primitive implementations.
- The user API is simple and small, with only two derives (`Record` and `Module`) and no additional attributes.
- Users can create their own `Module` and `Record` primitive types, which gives them the flexibility to control how their data is serialized without having to fork the framework.
##### Cons
- There are more types, but most of them are automatically generated and single-purpose, so users don't need to interact with them for common use cases.
However, they can do so if necessary.
- When instantiating a new record manually, each field must be set to something, even if the type itself is `()`, which represents no value.
Since the code generation step uses associated types, it doesn't know that a field type is actually nothing.
Creating a record manually without using the generated function `into_record` or loading it from a file is only useful to load a set of parameters into a module from an arbitrary source.
Using the record may not be the optimal solution to this problem, and another API could be created in the future.
##### Compatibility
Records may become incompatible with previous versions of Burn, depending on the chosen format.
The more compact format (bincode) stores minimal information about the type, making it significantly smaller but less resilient to type changes such as adding an optional field.
At some point, it might be necessary to provide a translation script that can convert a more resilient format from a previous version into the more compact one.
### Tensor
A proper deep learning framework should have a fast tensor implementation with autodiff support, and Burn is no exception.
The tensor API abstracts away backend implementation details and focuses on usability without compromising performance.
To make it as easy as possible to use, there is only one tensor type, unlike in many other tensor and deep learning crates in Rust.
Generic parameters are used instead to specialize the tensor type.
- **B: Backend:**
The first argument is the backend on which the tensor implementation lies.
- **const D: usize:**
The second argument is the dimensionality of the tensor.
- **K: TensorKind:**
The third argument is the tensor kind, which can be either `Float`, `Int`, or `Bool`.
By default, the tensor kind is set to `Float`, so for most tensors, the kind argument is not necessary.
Having one struct for tensors reduces the complexity of the tensor API, which also means less duplicated documentation to write and maintain.
Tensors are thread-safe, which means that you can send a tensor to another thread, and everything will work, including auto-differentiation.
Note that there are no in-place tensor operations, since all tensor operations take owned tensors as parameters, which makes it possible to mutate them.
Tensors can be shared simply by cloning them, but if there is only one reference to a tensor, the backend implementation is free to reuse the tensor's allocated data.
For more information about how it is done, you can have a look at this [blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).
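A trimmed-down sketch of this single-type design (illustrative, not the real definition):

```rust
use std::marker::PhantomData;

trait Backend {}

// Marker types for the tensor kind.
struct Float;
struct Int;
struct Bool;

/// One tensor type for everything; the generics carry the specialization.
/// `K` defaults to `Float`, so `Tensor<B, 2>` is a float matrix while
/// `Tensor<B, 2, Int>` holds integers of rank 2.
struct Tensor<B: Backend, const D: usize, K = Float> {
    // The real struct stores a backend primitive; elided in this sketch.
    _marker: PhantomData<(B, K)>,
}
```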
#### Backend
The Backend trait abstracts multiple things:
- Device type
- Float tensor type
- Bool tensor type
- Int tensor type
- Float element type
- Int element type
- Float tensor operations (kernels)
- Int tensor operations (kernels)
- Bool tensor operations (kernels)
Even though having one type for tensors is convenient for the tensor API, it can be cumbersome when implementing a backend.
Therefore, backends can decide, through associated types, what types they want to use for their int, float, and bool tensors.
Since float and int can have multiple precisions, the float and int element types are also associated types that must be declared by the backend.
Note that the backend chooses the precision, not the user.
Since not all backends will support the same element types, no assumptions should be made.
Therefore, there are no methods on tensors to change the precision, except for the `to_full_precision` function, which ensures numerical stability on the current backend.
Backend implementations can provide a way to choose the precision, which can be accomplished with a generic parameter (e.g. `NdArray<f32>`).
To be as general as possible, tensor operations are implemented as plain functions.
There is no object or self, just functions that take tensors as input and often return tensors as output as well.
Backend implementations are free to use their own patterns to implement these kernels.
Note that Burn is a dynamic graph deep learning framework, so backends may have to implement asynchronous kernel executions for performance reasons.
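A hedged sketch of the shape of this abstraction (heavily trimmed; the real trait has many more associated items):

```rust
/// Hypothetical, trimmed-down version of the backend abstraction described
/// above; not Burn's actual `Backend` trait.
trait Backend: Sized + 'static {
    type Device: Clone;

    // Element precisions are chosen by the backend, not by the user.
    type FloatElem;
    type IntElem;

    // Each tensor kind gets its own primitive type, generic over the rank.
    type FloatTensor<const D: usize>;
    type IntTensor<const D: usize>;
    type BoolTensor<const D: usize>;

    // Kernels are plain functions: no `self`, tensors in, tensors out.
    fn float_matmul<const D: usize>(
        lhs: Self::FloatTensor<D>,
        rhs: Self::FloatTensor<D>,
    ) -> Self::FloatTensor<D>;
}
```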
#### Autodiff
As of now, there is only one backend decorator that supports autodiff.
It follows the decorator pattern, making any backend differentiable.
However, the `AutodiffBackend` trait abstracts how gradients are calculated, and other approaches to autodiff might be added later.
For more information about how the current autodiff backend works, you can read this [blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).
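As an illustration of the decorator pattern described here, under the assumption of a much-simplified `Backend` trait:

```rust
use std::marker::PhantomData;

// Stand-in backend trait for the sketch; not Burn's real one.
trait Backend {
    fn float_matmul(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32>;
}

/// Decorator: `Autodiff<B>` is itself a backend that wraps any inner
/// backend `B`, making it differentiable.
struct Autodiff<B: Backend>(PhantomData<B>);

impl<B: Backend> Backend for Autodiff<B> {
    fn float_matmul(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        // 1. Record the operation so gradients can be computed during the
        //    backward pass (elided in this sketch).
        // 2. Delegate the actual computation to the inner backend.
        B::float_matmul(lhs, rhs)
    }
}
```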

Cargo.lock generated
View File

@@ -714,9 +714,8 @@ dependencies = [
 [[package]]
 name = "candle-core"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
+version = "0.5.0"
+source = "git+https://github.com/huggingface/candle.git?branch=main#af955f260cc20364b3e000c895dcb134a46e4e94"
 dependencies = [
  "accelerate-src",
  "byteorder",
@@ -741,18 +740,16 @@ dependencies = [
 [[package]]
 name = "candle-kernels"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
+version = "0.5.0"
+source = "git+https://github.com/huggingface/candle.git?branch=main#af955f260cc20364b3e000c895dcb134a46e4e94"
 dependencies = [
  "bindgen_cuda",
 ]

 [[package]]
 name = "candle-metal-kernels"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09"
+version = "0.5.0"
+source = "git+https://github.com/huggingface/candle.git?branch=main#af955f260cc20364b3e000c895dcb134a46e4e94"
 dependencies = [
  "metal",
  "once_cell",

View File

@@ -1,127 +0,0 @@
<mxfile host="app.diagrams.net" modified="2023-05-04T20:31:00.285Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36" etag="TZ0g_z0E95Old5D0NmzO" version="21.2.1" type="device">
<diagram name="Page-1" id="p9OtIezBOMlZGQ46mofZ">
<mxGraphModel dx="2245" dy="2369" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="0A-74KZi9Q9iTp0nixR_-37" value="" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=3;" parent="1" vertex="1">
<mxGeometry x="950" y="-670" width="260" height="270" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-9" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;strokeWidth=3;endArrow=open;endFill=0;endSize=14;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-1" target="0A-74KZi9Q9iTp0nixR_-3" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-1" value="Module" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="30" y="-1140" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-7" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;strokeWidth=3;endArrow=diamondThin;endFill=0;fontStyle=1;endSize=20;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-3" target="0A-74KZi9Q9iTp0nixR_-6" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-3" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Module&lt;br style=&quot;font-size: 18px;&quot;&gt;Trait&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="40" y="-990" width="150" height="150" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=3;endArrow=diamondThin;endFill=0;endSize=20;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-6" target="0A-74KZi9Q9iTp0nixR_-29" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-6" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Record&lt;br style=&quot;font-size: 18px;&quot;&gt;Trait&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="290" y="-990" width="150" height="150" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-24" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;strokeWidth=3;endArrow=open;endFill=0;endSize=14;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-10" target="0A-74KZi9Q9iTp0nixR_-6" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-10" value="Module&lt;br style=&quot;font-size: 18px;&quot;&gt;Record" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="280" y="-1140" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-12" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;strokeWidth=3;endArrow=diamondThin;endFill=0;fontStyle=1;endSize=20;" parent="1" target="0A-74KZi9Q9iTp0nixR_-13" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="973" y="-585" as="sourcePoint" />
<mxPoint x="1053" y="-585" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-13" value="Associative type" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=16;fontStyle=1;spacingLeft=6;" parent="1" vertex="1">
<mxGeometry x="1040" y="-600" width="140" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-14" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=3;endArrow=open;endFill=0;endSize=14;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" parent="1" target="0A-74KZi9Q9iTp0nixR_-15" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="973" y="-540" as="sourcePoint" />
<mxPoint x="1023" y="-540.5" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-15" value="Implement trait" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=16;fontStyle=1;spacingLeft=6;" parent="1" vertex="1">
<mxGeometry x="1040" y="-555" width="140" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-16" value="&lt;span&gt;&lt;br&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="990" y="-508" width="30" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-17" value="Trait" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=16;fontStyle=1;spacingLeft=6;" parent="1" vertex="1">
<mxGeometry x="1040" y="-508" width="140" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-18" value="" style="rounded=1;whiteSpace=wrap;html=1;fontSize=22;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="985" y="-460" width="40" height="20" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-19" value="Type" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=16;fontStyle=1;spacingLeft=6;" parent="1" vertex="1">
<mxGeometry x="1040" y="-465" width="140" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-26" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Serialize&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="220" y="-545" width="120" height="120" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-27" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Deserialize&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="390" y="-545" width="120" height="120" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=3;endArrow=open;endFill=0;endSize=14;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-29" target="0A-74KZi9Q9iTp0nixR_-26" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-32" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;strokeWidth=3;endArrow=open;endFill=0;endSize=14;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-29" target="0A-74KZi9Q9iTp0nixR_-27" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-29" value="Record Item&lt;br style=&quot;font-size: 18px;&quot;&gt;&amp;lt;&lt;font color=&quot;#b85451&quot;&gt;Precision&lt;font style=&quot;font-size: 18px;&quot;&gt;Settings&lt;/font&gt;&lt;/font&gt;&amp;gt;" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="240" y="-750" width="250" height="100" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;endArrow=diamondThin;endFill=0;endSize=20;strokeWidth=3;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-33" target="0A-74KZi9Q9iTp0nixR_-35" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-42" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;endArrow=diamondThin;endFill=0;endSize=20;strokeWidth=3;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-33" target="0A-74KZi9Q9iTp0nixR_-34" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-33" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Precision&lt;br style=&quot;font-size: 18px;&quot;&gt;Settings&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="890" y="-1150" width="150" height="150" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-34" value="Float Elem" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="1040" y="-980" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-35" value="Int Elem" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="1040" y="-890" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-39" value="&lt;b&gt;&lt;font style=&quot;font-size: 23px;&quot;&gt;Legend&lt;/font&gt;&lt;/b&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
<mxGeometry x="950" y="-650" width="260" height="30" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-44" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;endArrow=diamondThin;endFill=0;endSize=20;strokeWidth=3;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-47" target="0A-74KZi9Q9iTp0nixR_-50" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-45" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;endArrow=diamondThin;endFill=0;endSize=20;strokeWidth=3;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-47" target="0A-74KZi9Q9iTp0nixR_-49" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-46" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;endArrow=diamondThin;endFill=0;endSize=20;strokeWidth=3;fontSize=18;" parent="1" source="0A-74KZi9Q9iTp0nixR_-47" target="0A-74KZi9Q9iTp0nixR_-48" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="b0dmkdkAH7MZCtueex0P-2" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;strokeWidth=3;endArrow=diamondThin;endFill=0;endSize=20;" edge="1" parent="1" source="0A-74KZi9Q9iTp0nixR_-47" target="b0dmkdkAH7MZCtueex0P-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-47" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;Recorder&lt;br style=&quot;font-size: 18px;&quot;&gt;&lt;/span&gt;" style="rhombus;whiteSpace=wrap;html=1;strokeWidth=3;fontSize=18;fontStyle=1;fillColor=#f8cecc;strokeColor=#b85450;" parent="1" vertex="1">
<mxGeometry x="520" y="-1150" width="150" height="150" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-48" value="Record Args" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="670" y="-980" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-49" value="Record Output" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="670" y="-890" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="0A-74KZi9Q9iTp0nixR_-50" value="Load Args" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" parent="1" vertex="1">
<mxGeometry x="670" y="-800.5" width="170" height="70" as="geometry" />
</mxCell>
<mxCell id="b0dmkdkAH7MZCtueex0P-1" value="Precision Settings" style="rounded=1;whiteSpace=wrap;html=1;fontSize=18;fontStyle=1;strokeWidth=3;fillColor=#fff2cc;strokeColor=#d6b656;" vertex="1" parent="1">
<mxGeometry x="670" y="-710" width="170" height="70" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -2,8 +2,8 @@
 - [How to Read This Book](./how-to-read-this-book.md)
 - [Getting Started](./getting-started/ReadMe.md)
 - [Setting Up The Environment](./getting-started/setting-up-the-environment.md)
-- [Configuring Your Editor(Optional)](./getting-started/configuring-your-editor.md)
-- [testing](./getting-started/testing.md)
+- [Configuring Your Editor (Optional)](./getting-started/configuring-your-editor.md)
+- [Testing](./getting-started/testing.md)
 - [Architecture Overview](./project-architecture/ReadMe.md)
 - [Modules](./project-architecture/module.md)
 - [Serialization](./project-architecture/serialization.md)
@@ -12,5 +12,5 @@
 - [Guides for Contributors](./guides/ReadMe.md)
 - [Onnx To Burn Conversion Tool: A Development Guide](./guides/onnx-to-burn-conversion-tool.md)
 - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)
-- [Frequently Encountered Issues](./guides/frequently-encountered-issues/ReadMe.md)
-- [Issues Related To Adding Operators](./guides/frequently-encountered-issues/issues-while-adding-ops.md)
+- [Frequently Encountered Issues](./frequently-encountered-issues/ReadMe.md)
+- [Issues Related To Adding Operators](./frequently-encountered-issues/issues-while-adding-ops.md)

View File

@@ -1,3 +1,5 @@
 # Frequently Encountered Issues

-This is a collection of issues people have encountered and asked about on the discord, and is separate from the guides since it often involves a wall of text that is only relevant to a small subset of contributors.
+This is a collection of issues people have encountered and asked about on the
+[Discord server](https://discord.gg/uPEBbYYDB6). This section is separated from the guides since it
+can involve lots of details that are only relevant to a small subset of contributors.

View File

@@ -1,8 +1,9 @@
 # Issues encountered while adding ops

 Below are some of the issues that were encountered while adding ops to the project. If you encounter
-an issue while adding an op that isn't listed here, and it's not obvious how to fix it, please add
-it to this list. Also, reach out on the [discord server](https://discord.gg/uPEBbYYDB6) if you need help.
+an issue while adding an op that isn't listed here, and it's not obvious how to fix it, you can add
+it to this list or reach out on the [Discord server](https://discord.gg/uPEBbYYDB6) if you need
+help.

 ## Off by .000001 errors

@@ -12,9 +13,10 @@ it to this list. Also, reach out on the [discord server](https://discord.gg/uPEB
 tests::maxmin::tests::test_mean_dim_2d stdout ---- thread 'tests::maxmin::tests::test_mean_dim_2d' panicked at burn-wgpu/src/lib.rs:49:5: assertion `left == right` failed left: Data { value: [1.0, 4.0], shape: Shape { dims: [2, 1] } } right: Data { value: [0.99999994, 3.9999998], shape: Shape { dims: [2, 1] } }
 ```
-If you encounter this, swap out the `assert_eq!` in the failing test for `tensor1.assert_approx_eq` with `3` as the second argument. The second arguments specifies the level of precision. `3` is equivalent to a less than 0.001 difference between the elements of the two tensors.
+If you encounter this, swap out the `assert_eq!` in the failing test for
+`tensor1.to_data().assert_approx_eq` with `3` as the second argument. The second argument specifies
+the level of precision: `3` is equivalent to a less than 10<sup>-3</sup> (0.001) difference between
+the elements of the two tensors.
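For instance, with hypothetical `output` and `expected` values from such a failing test, the change looks like:

```rust
// Before: brittle, bit-exact float comparison.
// assert_eq!(output.to_data(), expected);

// After: approximate comparison; `3` tolerates per-element differences
// below 10^-3.
output.to_data().assert_approx_eq(&expected, 3);
```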
 ## Mismatched types and missing functions
@@ -29,13 +31,13 @@ error[E0599]: no method named `powi` found for struct `Tensor` in the current sc
 For more information about an error, try `rustc --explain E0308`. error: could not compile `onnx-tests` (test "onnx_tests") due to 3 previous errors
 ```
-If you are getting this error, you probably didn't implement your operator for the actual Tensor struct.
-This issue was encountered when adding the Pow operator. The operation was added to the
+If you are getting this error, you probably didn't implement your operator for the actual Tensor
+struct. This issue was encountered when adding the Pow operator. The operation was added to the
 `FloatTensorOps` and `IntTensorOp` traits, but not for the numeric trait (under
 `burn-tensor/src/tensor/api/numeric.rs`). This, coupled with `powf` existing prior to the PR though
 only for scalar values (which had been renamed, just not in the right place), led to this confusing
 issue where it looked like the function was found, but the type was wrong. If that's the case, make
 sure that it's implemented for the appropriate type, in this case `Float` under
-[burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/4ca3e31601228952bb1c1492bc9cd2adf15b5cf1/burn-tensor/src/tensor/api/numeric.rs#L2186),
+[crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/api/numeric.rs),
 and calling the `TensorOp.foo_op` defined under
-[burn-tensor/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/4ca3e31601228952bb1c1492bc9cd2adf15b5cf1/burn-tensor/src/tensor/ops/tensor.rs#L873)
+[crates/burn-tensor/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/ops/tensor.rs)

View File

@@ -1,3 +1,6 @@
 # Getting Started

-This section is for setting up the environment and how to do basic development tasks such as running tests and checking your code before committing. If you need help with the process or run into issues, feel free to ask in the [discord server](https://discord.gg/uPEBbYYDB6)
+This section is for setting up the environment and how to do basic development tasks such as running
+tests and checking your code before committing. If you need help with the process or run into
+issues, feel free to ask on the [Discord server](https://discord.gg/uPEBbYYDB6) in the Development
+channels.

View File

@@ -1,28 +1,36 @@
 # Configuring your editor

-These are not required, and most of this isn't specific to Burn, but it's definitely helpful if you
-haven't already done it.
+These steps are not required, and most of this isn't specific to Burn, but it's definitely helpful
+if you haven't already done it.

 ## VSCode

 Install the following extensions:

-- [rust-lang.rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer)
-- [tamasfe.even-better-toml](https://marketplace.visualstudio.com/items?itemName=tamasfe.even-better-toml)
-- [serayuzgur.crates](https://marketplace.visualstudio.com/items?itemName=serayuzgur.crates)
-- [vadimcn.vscode-lldb](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb)
+- [rust-lang.rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer)
+  for Rust syntax and semantic analysis
+- [tamasfe.even-better-toml](https://marketplace.visualstudio.com/items?itemName=tamasfe.even-better-toml)
+  for TOML syntax and semantic analysis
+- [serayuzgur.crates](https://marketplace.visualstudio.com/items?itemName=serayuzgur.crates) for
+  managing dependencies
+- [vadimcn.vscode-lldb](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb) for
+  debugging

 ### Setting up the Debugger

 To use the debugger, follow these steps:

-1. Open `Command Palette` with `Ctrl+Shift+P` or `F1` and type `LLDB: Generate Launch Configurations from Cargo.toml` then select it, this will generate a file that should be saved as `.vscode/launch.json`.
-2. Select the configuration from the "run and debug" side panel, then select the target from
-3. Now you can enable breakpoint on code through IDE and then start debugging the library/binary you want, such as the following example:
+1. Open `Command Palette` with `Ctrl+Shift+P` or `F1` and type
+   `LLDB: Generate Launch Configurations from Cargo.toml` then select it, this will generate a file
+   that should be saved as `.vscode/launch.json`.
+2. Select the configuration from the "run and debug" side panel, then select the target from the
+   list.
+3. Now you can enable breakpoints on code through the IDE, then start debugging the library/binary
+   you want, like in the following example:

 ![debug-options](debug-options-vscode.png)

-If you're creating a new library or binary, keep in mind to repeat the step 1. to always keep a fresh list of targets.
+If you're creating a new library or binary, keep in mind to repeat step 1 to always keep a fresh
+list of targets.

 ## Have another editor? Open a PR!

View File

@@ -1,41 +1,47 @@
 # Setting up the environment

-There are a couple of tools that need to be installed, and commands to be familiar with, depending
-on what part of the project you plan on contributing to. This section should be up to date with
-current project practices (as of 2024-01-26)
+Depending on what part of the project you plan on contributing to, there are a couple of tools to
+install and commands to be familiar with. This section should be up to date with current project
+practices (as of 2024-04-15).

 ## General

-There are a few commands you want to run prior to any commit for a non-draft PR:
+There are a few commands you will want to run prior to any commit for a non-draft PR:

-1. `cargo clippy --fix --allow-dirty`, this will run clippy and fix any issues it can, the allow
-   dirty flag is required whenever you have uncommitted changes
-2. `cargo fmt --all`, this will run rustfmt on all files in the project
-3. `./run_checks.sh all`, this is a script located in the project root that builds and tests the
-   project. It is required that this passes prior to merging a PR. Fair warning, running these tests
-   can take a while[^linux_mem_note].
+1. `cargo fmt --all` will run `rustfmt` on all files in the project.
+2. `cargo clippy --fix` will run [Clippy](https://github.com/rust-lang/rust-clippy) and fix any
+   coding issues it can. Clippy needs to be run in a clean Git state, but this can be
+   circumvented by adding the `--allow-dirty` flag.
+3. `cargo xtask run-checks all` is a script located in the project root that builds and tests the
+   project. It is required to run successfully prior to merging a PR. Fair warning, running these
+   tests can take a while[^linux_mem_note].

 ## Updating the burn semver version

-If for some reason you need to bump for the next version (though that should probably be left to the maintainers), edit the semantic version number in `burn/Cargo.toml`, and then run
-`cargo update` to update the lock file.
+If for some reason you need to bump for the next version (though that should probably be left to the
+maintainers), edit the semantic version number in `burn/Cargo.toml`, and then run `cargo update` to
+update the lock file.

 ## Contributing to either the Burn Book or Contributor Book

-Both the Burn Book and the Contributor Book are built with mdbook. If in the process of adding or
-modifying a page in the books you need to inspect the generated output (such as when using mathjax,
-which seems prone to breakage), run `mdbook --open <path/to/book>` or run
-`cargo xtask books {burn|contributor} open` which will install and use mdbook automatically.
-Alternatively, if you want to install mdbook directly, run the following command[^update_note]:
+Both the Burn Book and the Contributor Book are built with mdbook. To open the book locally, run
+`mdbook serve <path/to/book>` or `cargo xtask books {burn|contributor} open` which will install and
+use mdbook automatically.
+
+Alternatively, if you want to install mdbook directly, run the
+following command[^update_note]:

 ```bash
 cargo install mdbook
 ```

-Also instead of running `./run_checks.sh all`, you can run `./run_checks.sh typo`, or
-`cargo xtasks run-checks typo`, to only check for misspellings. This will install
-[typo](https://crates.io/crates/typos-cli), and if any are encountered you should be able to run
-`typo -w /path/to/book` to fix them.
+Also instead of running `cargo xtask run-checks all`, you can run `cargo xtask run-checks typos` to
+only check for misspellings. This will install [typo](https://crates.io/crates/typos-cli), and if
+any are encountered you should be able to run `typo -w /path/to/book` to fix them.

-[^linux_mem_note]: If your system is running into issues with memory and you are on linux, you may want to switch to a [virtual console](https://wiki.archlinux.org/title/Linux_console#Virtual_consoles) to run the tests. To do this, press `ctrl+alt+f3` to switch to a virtual console (and log in), and either `ctrl+alt+f2` or `ctrl+alt+f1` to switch back to your graphical session.
+[^linux_mem_note]: If your system is running into issues with memory and you are on linux, you may want to switch
+    to a [virtual console](https://wiki.archlinux.org/title/Linux_console#Virtual_consoles) to run
+    the tests. To do this, press `ctrl+alt+f3` to switch to a virtual console (and log in), and
+    either `ctrl+alt+f1` or `ctrl+alt+f2` to switch back to your graphical session.

-[^update_note]: You might also want to install [cargo-update](https://github.com/nabijaczleweli/cargo-update) to easily keep your tools up to date, though it is in no way required.
+[^update_note]: You might also want to install [cargo-update](https://github.com/nabijaczleweli/cargo-update) to
+    easily keep your tools up to date, though it is in no way required.

View File

@@ -1,32 +1,30 @@
 # Testing

-## Test for TensorOps
+## Test for Tensor Operations

-The following examples use matrix multiplication operation.
-
-Test for Tensor operations (as in given this input, expect it match or approximate this output) are
-defined only in
-[`burn-tensor/src/test/ops`](https://github.com/tracel-ai/burn/blob/b9bd42959b0d3e755a25e383cb5b38beb25559b8/burn-tensor/src/tests/ops/matmul.rs#L1)
-and not in the backends, with the exception of `burn-autodiff`. These tests are added to the
-`testgen_all` macro rule in
-[`burn-tensor/src/test/mod.rs`](https://github.com/tracel-ai/burn/blob/b9bd42959b0d3e755a25e383cb5b38beb25559b8/burn-tensor/src/tests/mod.rs#L59).
+Tests for tensor operations (generally of the form: given this input, expect it to match or
+approximate this output) are defined only in
+[`crates/burn-tensor/src/test/ops`](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/ops)
+and not in the backends (with the exception of `burn-autodiff`). The tensor operation tests are
+added to the `testgen_all` macro rule in
+[`crates/burn-tensor/src/tests/mod.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/mod.rs).
 This is then propagated to the existing backends without any additional work.

 ### Test for Autodiff

-The following examples use the power operation.
-
 Tests for autodiff go under
-[burn-autodiff/src/tests/{op_name}.rs](https://github.com/tracel-ai/burn/blob/4ca3e31601228952bb1c1492bc9cd2adf15b5cf1/burn-autodiff/src/tests/pow.rs#L31)
-(replace `op_name` with whatever makes sense for your op), and for tensor operations both the left and
-right sides need to be verified. The easiest way to do this, is to:
+[burn-autodiff/src/tests](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-autodiff/src/tests)
+and should verify backward pass correctness. For binary tensor operations, both the left and right
+sides need to be verified.

-1. use small tensors with simple values
-2. pop open a terminal, launch `ipython` and import `numpy` then do the calculations by hand. You
-   can also use [google colab](https://colab.google/) if you prefer so that you don't have to
-   install the packages on your system.
-3. compare the actual output to the expected output for lhs, rhs and regular operation
+Here's an easy way to define tests for a new operation's backward pass:

-Generally, it seems preferable to use
-`actual_output_tensor.into_data().assert_approx_eq(&expected_tensor_data,3)` to `assert_eq!(...` due
-to occasional hiccups with floating point calculations.
+1. Use small tensors with simple values.
+2. Pop open a terminal, launch `ipython` and import `numpy`, then do the calculations by hand. You
+   can also use [Google Colab](https://colab.google/) so you don't have to install the packages on
+   your system.
+3. Compare the actual outputs to the expected output for the left-hand side and right-hand side.
+
+For float tensors, it is advised to use
+`actual_output_tensor.into_data().assert_approx_eq(&expected_tensor_data,3)` instead of
+`assert_eq!(...` due to occasional hiccups with floating point calculations.

View File

@@ -1,3 +1,3 @@
 # Guides for Contributors

-The following guides are meant to help contributors trying to accomplish specific tasks, such as adding new operations to Burn or generating test models for `burn-import`.
+The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn or generating test models for `burn-import`.

View File

@ -1,49 +1,49 @@
# Adding a new operation to burn # Adding a New Operation to burn
Let's discuss how one might go about adding new operators to Burn, using the example of the pow Let's discuss how one might go about adding new operators to Burn, using the example of the pow
operator added in [this PR](https://github.com/tracel-ai/burn/pull/1133/files). In that PR, the operator added in [this PR](https://github.com/tracel-ai/burn/pull/1133/files).
following things took place (albeit not in this order).
## Adding the Op to burn-tensor ## Adding the Op to burn-tensor
`burn-tensor` is the crate that defines all tensor operations that need to be implemented by the `burn-tensor` is the crate that defines all tensor operations that need to be implemented by the
various backends. The core of this lies in `crates/burn-tensor/src/tensor/api/numeric.rs`, which is various backends. The core of this lies in
home to the numeric trait and its implementation for the different tensor types. The numeric trait [crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs),
is the home of all tensor operations that are numeric in nature and that are shared by `Int` and which is home to the numeric trait and its implementation for the different tensor types. The
`Float` Tensor types. More information on the relationship between Tensor modules can be found under numeric trait is the home of all tensor operations that are numeric in nature and that are shared by
the section for [Tensor Architecture](../project-architecture/Tensor.md#tensorops). `Int` and `Float` Tensor types. More information on the relationship between Tensor modules can be
found under the section for [Tensor Architecture](../project-architecture/Tensor.md#tensorops).
Here is where pow was added to `crates/burn-tensor/src/tensor/api/numeric.rs`: Here is where pow was added to `crates/burn-tensor/src/tensor/api/numeric.rs`:
1. for the 1. for the
[`Tensor<Backend,Dimension,Kind>` struct](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/numeric.rs#L565) [`Tensor<Backend, Dimension, Kind>` struct](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L573)
2. for the 2. for the
[numeric trait](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/numeric.rs#L1922) [numeric trait](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L1955)
3. for the implementation of numeric for 3. for the implementation of numeric for
[float](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/numeric.rs#L2677) [float](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2722)
and and
[int](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/numeric.rs#L2336) [int](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2375)
Tensor is a struct that has a single member: `primitive` (defined Tensor is a struct that has a single member: `primitive` (defined
[here](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/base.rs#L27)), [here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/base.rs#L27)),
that is defined by it's that is defined by its
[`Kind`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/kind.rs#L16): [`Kind`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/kind.rs#L16):
one of `Bool`, `Float`, or `Int` (those linked in 3). These call the ops for that data type defined one of `Bool`, `Float`, or `Int` (those linked in 3). These call the ops for that data type defined
in the in the
[`Backend`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/backend/base.rs#L54) [`Backend`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/backend/base.rs#L54)
supertrait[^supertrait]. This is the trait that is then implemented by the different `burn-` supertrait[^supertrait]. This is the trait that is then implemented by the different `burn-`
backends (such as `burn-ndarray` and `burn-wgpu`) which implement the functions if no default is backends (such as `burn-ndarray` and `burn-wgpu`) which must implement the functions if no default
provided. is provided.
In this case, we don't need to worry about `Bool` Tensors. Ops for `Float` is implemented under In this case, we don't need to worry about `Bool` Tensors. Ops for `Float` is implemented under
[burn-tensor/src/tensor/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/ops/tensor.rs#L977), [`crates/burn-tensor/src/tensor/ops/tensor.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/ops/tensor.rs#L991),
and for `Int` under and for `Int` under
[`burn-tensor/src/tensor/ops/int_tensor.rs`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/ops/int_tensor.rs#L539). [`crates/burn-tensor/src/tensor/ops/int_tensor.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/ops/int_tensor.rs#L539).
The current convention is ops of each type, if not unique to that type, are prefixed with the type. The current convention is ops of each type, if not unique to that type, are prefixed with the type.
So `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for So `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for
`FloatTensorOps`. If an op is unique to a type, then it should be implemented under `FloatTensorOps`. If an op is unique to a type, then it should be implemented under
`burn-tensor/src/api/{type}.rs`. For example, here is an implementation for `burn-tensor/src/api/{type}.rs`. For example, here is an implementation for
[`sin` under `burn-tensor/src/api/float.rs`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/float.rs#L82) [`sin` under `crates/burn-tensor/src/api/float.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/float.rs#L82)
which obviously doesn't make sense for `Int` or `Bool` tensors. which obviously doesn't make sense for `Int` or `Bool` tensors.
The `Int` Tensor function uses the ones defined for Float with 2 extra casts (LHS to a `Float` The `Int` Tensor function uses the ones defined for Float with 2 extra casts (LHS to a `Float`
@ -53,37 +53,53 @@ implementations.
### Adding Tests ### Adding Tests
Additional tests should be added to `burn-tensor` under
[`crates/burn-tensor/src/tests/ops/{op_name}.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tests/ops/powf.rs#L1),
inserting the module name into `crates/burn-tensor/src/tests/ops/mod.rs`. Then add it to the
`testgen_all` macro under `crates/burn-tensor/src/tests/mod.rs`. This macro is called from the
`lib.rs` file in each backend, which autogenerates the tests for that specific backend. It isn't
necessary to define tests in the backends directly, save for those that require specific testing
such as `burn-autodiff`.
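To give a feel for the shape such a test takes, here is a rough sketch (the tensor values, the
`TestTensor` helper, and the exact assertion calls are illustrative; copy a neighboring test file
for the precise current API):

```rust
#[burn_tensor_testgen::testgen(powf)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_support_powf_ops() {
        let lhs = TestTensor::from([[1.0, 2.0], [3.0, 4.0]]);
        let rhs = TestTensor::from([[2.0, 2.0], [2.0, 2.0]]);

        let output = lhs.powf(rhs);

        // Approximate comparison is preferred over `assert_eq!` because of
        // floating point rounding.
        output
            .into_data()
            .assert_approx_eq(&Data::from([[1.0, 4.0], [9.0, 16.0]]), 3);
    }
}
```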
## Adding the Op to burn-autodiff
Since this is probably the hardest and the least straightforward, we'll cover this backend
separately. `burn-autodiff` enables other backends to use autodifferentiation[^autodiff]. Ops for
float types are implemented in
[crates/burn-autodiff/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-autodiff/src/ops/tensor.rs)
and need to:
1. Define a unit struct[^absolute_units] that implements a backward (pass) function.
2. Within the backward function, since this is an elementwise binary operation, call the binary
   function (from `backward.rs` under the same directory); its last 2 arguments are two closures
   that define the left and right partial derivatives (see the sketch after this list).
3. Then define what happens when a specific operation is tracked or untracked: untracked just
   calls the function in the normal way, while tracked registers the backward function defined
   above to be executed during the backward pass.
4. When tracked, operations are part of the autodiff graph and must save the needed information to
   efficiently perform their backward pass later. If the information is light (such as a shape), it
   should be directly saved in the state. If the operation's inputs are needed to compute the
   backward pass, they should be checkpointed rather than saved. This allows the inputs to be
   provided lazily during the backward pass, depending on the checkpointing strategy.
5. An operation must also be identified as _compute-bound_ (`.computeBound()`) or _memory-bound_
   (`.memoryBound()`) for gradient checkpointing. _Compute-bound_ operations are heavy to compute
   (for instance matmul or convolution), which means that even with checkpointing they will save
   their output for the backward pass rather than recompute it. _Memory-bound_ operations are more
   trivial (like `powf`, which only performs one small operation per tensor entry), so it can be
   beneficial to recompute them during the backward pass instead of saving their whole forward
   output to memory. Operations registered as _memory-bound_ need to know their parents
   (`.parents()` method) and how to recompute their forward pass during the backward pass (with a
   struct that implements `RetroForward`), using their parents' outputs.
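To make the "two closures" idea concrete, here is a minimal self-contained sketch of the pattern
with plain `f64` values standing in for tensors (names are illustrative, not the actual
`burn-autodiff` API):

```rust
// A binary backward step: given the upstream gradient and the two forward
// inputs, apply the left and right partial-derivative closures.
fn binary_backward<L, R>(grad: f64, lhs: f64, rhs: f64, left: L, right: R) -> (f64, f64)
where
    L: Fn(f64, f64) -> f64,
    R: Fn(f64, f64) -> f64,
{
    (grad * left(lhs, rhs), grad * right(lhs, rhs))
}

fn main() {
    // For f(x, y) = x^y: df/dx = y * x^(y-1) and df/dy = x^y * ln(x).
    let (grad_lhs, grad_rhs) = binary_backward(
        1.0,
        2.0,
        3.0,
        |x, y| y * x.powf(y - 1.0),
        |x, y| x.powf(y) * x.ln(),
    );
    println!("{grad_lhs} {grad_rhs}"); // 12 and 8 * ln(2)
}
```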
The above steps are mostly boilerplate, so you can often just copy the contents of another similar
op, change the name of the structs, and ensure that both sides have the data they need (if they
need a copy of the opposite-sided tensor, clone its contents).
### Computing derivatives
For those that need it, here is a quick refresher on the necessary calculus. If you are familiar
with how to calculate partial derivatives, you can skip this section.
Since `pow` is a binary operation, the left and right functions are the partial derivatives with
respect to the left and right sided tensors.
Let's define the operator as a function \\(f(x,y)=x^{y}\\), where \\(x\\) is the left hand tensor
and \\(y\\) is the right hand tensor. Then

$$\frac{\partial}{\partial x} (x^{y}) = y \cdot x^{y-1}$$

is the left handed closure, and

$$\frac{\partial}{\partial y} (x^{y}) = x^{y} \cdot \ln(x)$$

is the right. If you aren't sure how to calculate these by hand, it is recommended to use
[symbolab](<https://www.symbolab.com/solver/partial-derivative-calculator/%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20x%7D%5Cleft(x%5E%7By%7D%5Cright)?or=input>),
plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the variable
\\(x\\)|\\(y\\) in the partial derivative to get the other side.
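If you'd rather double-check the closed forms numerically, a few lines of plain Rust comparing
them against central finite differences will do (step size and tolerance below are arbitrary
choices):

```rust
fn f(x: f64, y: f64) -> f64 {
    x.powf(y)
}

fn main() {
    let (x, y, h) = (2.0_f64, 3.0_f64, 1e-6);

    let dx_analytic = y * x.powf(y - 1.0); // y * x^(y-1)
    let dy_analytic = x.powf(y) * x.ln(); // x^y * ln(x)

    // Central finite differences as an independent reference.
    let dx_numeric = (f(x + h, y) - f(x - h, y)) / (2.0 * h);
    let dy_numeric = (f(x, y + h) - f(x, y - h)) / (2.0 * h);

    assert!((dx_analytic - dx_numeric).abs() < 1e-4);
    assert!((dy_analytic - dy_numeric).abs() < 1e-4);
}
```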
### Testing autodiff
For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md).
## Adding the Op to other backends
@ -123,36 +125,36 @@ Most of these are fairly straightforward implementations. For reference here's p
implementation for torch, ndarray and candle backends:
1. Torch implementation in
   [crates/burn-tch/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/tensor.rs#L467)
   and the Op used in
   [crates/burn-tch/src/ops/base.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/base.rs#L481)
2. NdArray in
   [crates/burn-ndarray/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-ndarray/src/ops/tensor.rs#L472)
3. Candle in
   [crates/burn-candle/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-candle/src/ops/tensor.rs#L504)
This is where any calculation happens currently. Playing a guessing game with method names and
seeing what completions are suggested will take you far. If you are having trouble figuring out how
to do it from the docs for that backend,
[try searching github for relevant function calls](https://docs.github.com/en/search-github/github-code-search/understanding-github-code-search-syntax).
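For a taste of what such an implementation looks like, here is a hedged sketch of an elementwise
binary op written directly against the `ndarray` crate (simplified to a fixed `f32` element type
and 1-D arrays; the real backend is generic over both):

```rust
use ndarray::{Array1, Zip};

// Elementwise powf over two equally-shaped arrays: the general shape of a
// CPU-backend binary kernel.
fn float_powf(lhs: &Array1<f32>, rhs: &Array1<f32>) -> Array1<f32> {
    Zip::from(lhs).and(rhs).map_collect(|&l, &r| l.powf(r))
}

fn main() {
    let lhs = Array1::from(vec![1.0_f32, 2.0, 3.0]);
    let rhs = Array1::from(vec![2.0_f32, 2.0, 2.0]);
    println!("{}", float_powf(&lhs, &rhs)); // [1, 4, 9]
}
```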
## Adding the Op to fusion, JIT and wgpu backends
Adding an operator to these backends can be fairly straightforward, though due to what these
backends are for, it involves a bit more indirection. Fusion and JIT, like autodiff, are not target
backends as much as backends that enable certain functionality for other backends, in this case
kernel fusion or just-in-time compilation (only available for the `burn-wgpu` backend at the
moment).
Adding the operator won't involve doing any calculation; you'll just be describing how the
generated code should look. Most of this can be copy/pasted/adjusted from other functions.
Here's how powf was added to `burn-fusion`:
1. Added powf to the float ops under
   [`crates/burn-fusion/src/ops/float.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/ops/float.rs#L1838)
2. Added powf to the `NumericOperationDescription` enum under
   [crates/burn-fusion/src/stream/operation.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/operation.rs#L433)
3. Added powf to the implementations of the `NumericOperationDescription` enum under
   [crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/context.rs#L771)
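The underlying idea is that a fusion backend records a description of the op instead of computing
it, and executes a whole stream of such descriptions later. Very roughly (the type names here are
made up, not the actual `burn-fusion` ones):

```rust
// An opaque handle to a tensor in the recorded stream.
#[derive(Clone, Copy, Debug)]
struct TensorId(u64);

// One variant per numeric op; execution is deferred until the stream
// decides which ops can be fused together.
#[derive(Debug)]
enum NumericOpDescription {
    Powf { lhs: TensorId, rhs: TensorId, out: TensorId },
}

fn main() {
    let desc = NumericOpDescription::Powf {
        lhs: TensorId(0),
        rhs: TensorId(1),
        out: TensorId(2),
    };
    println!("{desc:?}");
}
```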
The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
@ -163,70 +165,70 @@ implementation is defined in `burn-jit`.
Here is where code was added for powf in `burn-jit` and `burn-wgpu`:
1. to the implementation of
   [`FloatTensorOps` under `crates/burn-jit/src/ops/float_ops.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/float_ops.rs#L491)
2. the function being called was added to
   [crates/burn-jit/src/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/ops/numeric.rs#L229)
3. the operator was defined in
   [`crates/burn-jit/src/codegen/dialect/gpu/operation.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/operation.rs#L37)
4. the vectorization was added to
   [`crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs#L55)
5. how the operation looks to the gpu was added to
   [`crates/burn-jit/src/fusion/tracing/builder.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-jit/src/fusion/tracing/builder.rs#L279)
6. the mapping between the gpu operation and the WGSL instruction was added to
   [`crates/burn-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L455)
7. the WGSL instruction itself was added to the
   [instruction op enum in `crates/burn-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L103),
   and the actual
   [instruction in wgsl here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/instructions.rs#L273)
We needed to generate some custom WGSL code for powf, primarily due to issues with proper case
handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an even
power being positive. We reused as much of the existing logic as possible, and then branched at the
last point based off the var type of the rhs.
[See here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-wgpu/src/compiler/wgsl/compiler.rs#L596).
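As a plain-Rust rendering of the case handling described above (an illustration of the semantics
only, not the actual generated WGSL):

```rust
// The cases a naive exp2(y * log2(x))-style pow gets wrong: 0^0 must be 1,
// and a negative base with an integer exponent must keep the correct sign
// instead of producing NaN.
fn powf_checked(lhs: f32, rhs: f32) -> f32 {
    if rhs == 0.0 {
        return 1.0; // covers 0^0 == 1
    }
    if lhs < 0.0 && rhs.fract() == 0.0 {
        let magnitude = (-lhs).powf(rhs);
        // Even integer exponent -> positive result, odd -> negative.
        return if (rhs as i64) % 2 == 0 { magnitude } else { -magnitude };
    }
    lhs.powf(rhs)
}

fn main() {
    assert_eq!(powf_checked(0.0, 0.0), 1.0);
    assert_eq!(powf_checked(-2.0, 2.0), 4.0);
    assert_eq!(powf_checked(-2.0, 3.0), -8.0);
}
```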
For most operations, you shouldn't need to add to `crates/burn-wgpu/src/compiler/wgsl/extension.rs`
unless the operation isn't native to WGSL.
For functions that need a complex kernel without a direct mapping to a base instruction, it is not
as straightforward. An easier way of implementing them is underway.
## Adding the Op to burn-import
Generating the ONNX test files or tests is already covered
[in the ONNX to burn guide](onnx-to-burn-conversion-tool.md#adding-new-operators); this is more
about the specific changes you need to make when adding new operators after you have generated the
tests.
The crate is divided into two sections, `src/burn` and `src/onnx`. The code under the former
corresponds to the operation you've implemented earlier in this guide, and the latter to the
operations defined in the ONNX specification. So when you are loading a model, the operator is first
parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation
defined under `src/burn/node`.
Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`:
1. Determine the type of operator and add your operator to the appropriate node (operation) type, in
   this case
   [BinaryNode under `crates/burn-import/src/burn/node/binary.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/burn/node/binary.rs#L160)
   along with its
   [`to_str` definition](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/burn/node/binary.rs#L15)
2. Add an arm to the match statement inside the `into_burn` function in
   [crates/burn-import/src/onnx/to_burn.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L272)
   for the ONNX `NodeType` (which corresponds to an op in the ONNX spec), and make an
   [`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717)
   that maps the ONNX node to the binary type
3. Specify how dimensions for the output should be derived in
   [crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55)
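Schematically, steps 1 and 2 boil down to a two-enum mapping like the following toy sketch (type
names are made up, not the real `burn-import` signatures):

```rust
// The ONNX-side intermediate representation...
#[derive(Debug)]
enum NodeType {
    Add,
    Pow,
}

// ...and the Burn-side node it is converted into.
#[derive(Debug)]
enum BinaryType {
    Add,
    Powf,
}

impl BinaryType {
    // A to_str-style mapping decides which burn tensor method the
    // generated Rust code will call.
    fn as_str(&self) -> &'static str {
        match self {
            BinaryType::Add => "add",
            BinaryType::Powf => "powf",
        }
    }
}

fn into_burn(node: NodeType) -> BinaryType {
    match node {
        NodeType::Add => BinaryType::Add,
        NodeType::Pow => BinaryType::Powf,
    }
}

fn main() {
    println!("{}", into_burn(NodeType::Pow).as_str()); // "powf"
}
```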
And you're done! Congrats, you just fully added a new operation to Burn, and we are all one step
closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and
it's freaking fast!". Buy yourself a coffee.
[^supertrait]: for more on supertraits see
    [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)
[^autodiff]: wiki link for
    [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
[^absolute_units]: for more information on unit structs see
    [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)
@ -1 +0,0 @@
# Frequently Encountered Issues
@ -1 +0,0 @@
# Issues Related To Adding Operators
@ -5,14 +5,19 @@ ONNX to Burn conversion tool. This tool allows the importation of ONNX models in
learning framework written in Rust. It converts both ONNX models to Rust source code and model
weights to Burn state files.
For an introduction to ONNX import in Burn, see
[this section of the Burn book](https://burn.dev/book/import/onnx-model.html).
## Table of Contents
- [ONNX to Burn Conversion Tool: Development Guide](#onnx-to-burn-conversion-tool-development-guide)
  - [Table of Contents](#table-of-contents)
  - [Design Overview](#design-overview)
    - [Design Goals](#design-goals)
    - [Design Decisions](#design-decisions)
  - [Adding New Operators](#adding-new-operators)
  - [Testing](#testing)
  - [Resources](#resources)
## Design Overview
@ -41,15 +46,17 @@ The conversion process involves three main stages:
To extend `burn-import` with support for new ONNX operators, follow these steps:
1. **Create PyTorch Script**: Place a PyTorch script using the new operator under
   `crates/burn-import/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
   tensors for end-to-end testing.
2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model.
3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify the ONNX
   model contains the expected operators.
4. **Generate IR and Burn Graph**: Navigate to
   [crates/burn-import/](https://github.com/tracel-ai/burn/tree/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import)
   and run:

   ```
   cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
   ```
@ -61,8 +68,10 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
   the Burn model in Rust code, and `my-model.json` includes the model data.
7. **Add End-to-End Test**: Include the test in
   [crates/burn-import/onnx-tests/tests/onnx_tests.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import/onnx-tests/tests/onnx_tests.rs).
   Further details can be found in the
   [onnx-tests README](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import/onnx-tests/README.md).
## Testing
@ -1,22 +1,22 @@
# How to read this book
Throughout this book, we maintain the following structure.
## Linking
When referring to structures or functions within the codebase, we provide permalinks to the lines in
specific commits, and indicate them by the relative path of their parent file from the project root.
For example, this is a reference to the `Tensor` struct in
[`crates/burn-tensor/src/tensor/api/base.rs`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/base.rs#L27)
When some reference information is useful but is beyond the scope of contributing to Burn, we
provide that information in a footnote. To build on the previous example, the `Tensor` mentioned is
what's referred to as a newtype struct[^1].
Direct hyperlinks are for tools and resources that are not part of the Burn project, but are useful
for contributing to it. For example, when working on implementing an operation for autodiff, it can
be useful to use [symbolab](https://www.symbolab.com/) to calculate the left and right partial
derivatives.
[^1]: For more information on newtype please refer to
    [the Advanced Types chapter of the Rust Book](https://doc.rust-lang.org/book/ch19-04-advanced-types.html#using-the-newtype-pattern-for-type-safety-and-abstraction)
@ -7,18 +7,20 @@ provide some detailed guidance on how to contribute to the project.
We have crafted some sections for you:
- [Getting Started](./getting-started): Much like the [Burn Book](https://burn.dev/book/) which
  targets users, we'll start with the fundamentals, guiding you through tasks like setting up the
  development environment, running tests, and what you should check prior to each commit.
- [Project Architecture](./project-architecture): This section will give you an in-depth look at the
  architecture of Burn.
- [Guides](./guides): We provide some guides on how to do specific tasks, such as adding a new
  operation to Burn.
- [Frequently Encountered Issues](./frequently-encountered-issues): If you are running into an issue
  that has you stumped, this is the section to check out prior to asking on the
  [Discord](https://discord.gg/uPEBbYYDB6). It's a collection of errors encountered by contributors,
  what caused them, and how they were resolved.
As this book is geared towards contributors and not towards users of Burn, we'll assume you have a
good understanding of software development, but will make efforts to explain anything outside of
@ -1,19 +1,19 @@
# Project Architecture
This section documents most major architectural decisions with the reasoning behind them.
**Sections**
- [Module](./module.md)
  - [Optimization](./module.md#optimization)
    - [Constraints](./module.md#constraints)
    - [Solution](./module.md#solution)
- [Serialization](./serialization.md)
  - [Constraints](./serialization.md#constraints)
  - [Solution](./serialization.md#solution)
    - [Pros](./serialization.md#pros)
    - [Cons](./serialization.md#cons)
  - [Compatibility](./serialization.md#compatibility)
- [Tensor](./tensor.md)
- [Backend](./backend.md)
  - [Autodiff](./backend.md#autodiff)
@ -16,47 +16,50 @@ Having one struct for tensors reduces the complexity of the tensor API, which al
duplicated documentation to write and maintain.
Tensors are thread-safe, which means that you can send a tensor to another thread, and everything
will work, including auto-differentiation. Note that there are no explicit in-place tensor
operations since all tensor operations take owned tensors as parameters, which makes it possible to
mutate them. Tensors can be shared simply by cloning them, but if there is only one reference to a
tensor, the backend implementation is free to reuse the tensor's allocated data. For more
information about how it is done, you can have a look at this
[blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).
## Tensor Operations
Operations on Tensors (sometimes shortened to Ops) are defined in traits (generally part of the
Backend Supertrait) and implemented for the Tensor struct. The appropriate parent trait of an
operation depends on the type of operation:
- `base` => All tensor kinds should implement these operations (reshape, into_data, etc.). The
  implementation is in
  [crates/burn-tensor/src/tensor/api/base.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/base.rs).
- `numeric` => All tensors that are numeric by nature should implement these operations (Add, Sub,
  Div, etc.). The implementation is in
  [crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/numeric.rs).
- `Float` => Tensor operations are only available for float tensors. The implementation is in
  [crates/burn-tensor/src/tensor/api/float.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/float.rs).
- `Int` => Tensor operations are only available for int tensors. The implementation is in
  [crates/burn-tensor/src/tensor/api/int.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/int.rs).
- `bool` => Tensor operations are only available for bool tensors. The implementation is in
  [crates/burn-tensor/src/tensor/api/bool.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/bool.rs).
`Numeric` is directly implemented for `Float` and `Int` tensors, and in general, the implementations
for these methods are calling the corresponding `{Int|Float}TensorOps` method defined in the backend
supertrait.
Anything that is implemented by numeric should have an implementation in the `{Int|Float}TensorOps`
traits, though it may be avoidable if the operation for one type requires casting to the other type.
To provide an example, `powf` should be implemented for `Int` tensors, but it should not be an Int
Tensor Operation. The LHS should be converted to a float, and the output should be converted back to
an int. So it's possible to avoid implementing it in `IntTensorOps` altogether.
Additionally there are some operations that should be defined as functions instead of tensor op
methods. These are:
- `module` => These should be exported as functions instead of methods on tensors. The
  implementation is in
  [crates/burn-tensor/src/tensor/ops/module.rs](https://github.com/tracel-ai/burn/tree/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/modules).
- `activation` => These should also be exported as functions instead of methods on tensors. The
  implementation is in
  [crates/burn-tensor/src/tensor/ops/activation.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/activation.rs).
Note that some activations are just a combination of backend operations and are not declared there.
@ -12,6 +12,10 @@ The Backend trait abstracts multiple things:
- Int tensor operations (kernels)
- Bool tensor operations (kernels)
## Element types
> Warning: there are plans to change this architecture in the near future.
Even though having one type for tensors is convenient for the tensor API, it can be cumbersome when
implementing a backend. Therefore, backends can decide, through associated types, what types they
want to use for their int, float, and bool tensors. Since float and int can have multiple
@ -24,6 +28,8 @@ change the precision, except for the `to_full_precision` function, which ensures
on the current backend. Backend implementations can provide a way to choose the precision, which can
be accomplished with a generic parameter (e.g. `NdArray<f32>`).
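The associated-type pattern looks roughly like this (a simplified sketch, not the actual `Backend`
trait):

```rust
use std::marker::PhantomData;

// Each backend picks its own concrete element and tensor types.
trait BackendSketch {
    type FloatElem;
    type IntElem;
    type FloatTensor;
    type IntTensor;
    type BoolTensor;
}

// A hypothetical CPU backend generic over float precision, in the spirit
// of `NdArray<f32>`.
struct CpuBackend<F> {
    _precision: PhantomData<F>,
}

impl<F> BackendSketch for CpuBackend<F> {
    type FloatElem = F;
    type IntElem = i64;
    type FloatTensor = Vec<F>;
    type IntTensor = Vec<i64>;
    type BoolTensor = Vec<bool>;
}
```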
## Operations
To be as general as possible, tensor operations are implemented as plain functions. There is no
object or self, just functions that take tensors as input and often return tensors as output as
well. Backend implementations are free to use their own patterns to implement these kernels. Note
@ -35,5 +41,5 @@ kernel executions for performance reasons.
As of now, there is only one backend decorator that supports autodiff. It follows the decorator
pattern, making any backend differentiable. However, the `AutodiffBackend` trait abstracts how
gradients are calculated, and other approaches to autodiff might be added later. For more
information about how the current autodiff backend works, you can read this (slightly outdated)
[blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).
@ -6,13 +6,13 @@ declaration of the forward pass, leaving it up to the implementer to decide how
defined.
Additionally, most modules are created using a (de)serializable configuration, which defines the
structure of the module and its hyperparameters. Parameters and hyperparameters are not serialized
into the same file, and both are normally necessary to load a module for inference.
## Optimization
Optimization is normally done with variants of gradient descent, and it is important to provide an
easy API for optimizing modules.
### Constraints
@ -26,10 +26,10 @@ is important to provide an easy API for optimizing modules.
### Solution
In the following, the `Module` trait is defined in
[`crates/burn-core/src/module/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/module/base.rs#L83)
and the `Optimizer` trait is defined in
[`crates/burn-core/src/optim/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/base.rs#L8).
The solution to this problem comprises multiple parts. Firstly, the `Optimizer` trait is quite
similar to the `Module` trait, in terms of saving and loading the state. Please refer to the
@ -52,9 +52,8 @@ parameter tensors.
#### SimpleOptimizer
Located in
[`crates/burn-core/src/optim/simple/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/base.rs#L9),
the `SimpleOptimizer` has two major assumptions:
1. The state of the optimizer is linked to each parameter. In other words, each parameter has its
   own optimizer state, decoupled from the other parameters.
@ -70,9 +69,8 @@ used.
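To illustrate the per-parameter idea, here is a self-contained sketch in the same spirit, with
plain `Vec<f32>`s instead of Burn tensors and records (the real trait's signature differs):

```rust
// step() receives one parameter tensor, its gradient, and that parameter's
// own optional state, and returns both updated.
trait SimpleOptimizerSketch {
    type State;

    fn step(&self, lr: f32, tensor: &[f32], grad: &[f32], state: Option<Self::State>)
        -> (Vec<f32>, Self::State);
}

// Plain SGD with momentum as a minimal instance.
struct SgdMomentum {
    momentum: f32,
}

impl SimpleOptimizerSketch for SgdMomentum {
    type State = Vec<f32>; // one velocity value per element

    fn step(&self, lr: f32, tensor: &[f32], grad: &[f32], state: Option<Vec<f32>>)
        -> (Vec<f32>, Vec<f32>) {
        let mut velocity = state.unwrap_or_else(|| vec![0.0; grad.len()]);
        let updated = tensor
            .iter()
            .zip(grad)
            .zip(velocity.iter_mut())
            .map(|((&t, &g), v)| {
                *v = self.momentum * *v + g;
                t - lr * *v
            })
            .collect();
        (updated, velocity)
    }
}

fn main() {
    let opt = SgdMomentum { momentum: 0.9 };
    let (params, state) = opt.step(0.1, &[1.0, 2.0], &[0.5, -0.5], None);
    println!("{params:?} {state:?}"); // [0.95, 2.05] [0.5, -0.5]
}
```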
#### OptimizerAdaptor
Located in
[`crates/burn-core/src/optim/simple/adaptor.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/adaptor.rs#L14),
the `OptimizerAdaptor` is a simple struct composed of a `SimpleOptimizer` and a hashmap with all
records associated with each parameter ID.
When performing an optimization step, the adaptor handles the following:
@ -87,7 +85,7 @@ When performing an optimization step, the adaptor handles the following:
5. Updates the state for the current parameter and returns the updated tensor, making sure it's
   properly registered into the autodiff graph if gradients are marked as required.
Note that a parameter can still be updated by another process, as it is the case with running
metrics used in batch norm. These tensors are still wrapped using the `Param` struct so that they
are included in the module's state and given a proper parameter ID, but they are not registered in
the autodiff graph.
@ -16,7 +16,7 @@ solution.
This can include constants, database connections, other module references, or any other
information. Only parameters should be serialized since the structure of the module itself should
be encapsulated with module configurations (hyperparameters).
3. **Users should be able to declare the format in which the module should be saved.**
@ -45,12 +45,12 @@ main types and traits used in the solution.
<div align="center">
<h4>Module Serialization Types</h4>
<img src="./ModuleSerialization.png" width="700px"/>
<div align="left">
The way the types interact with each other is pretty straightforward. First, a module can be
converted into a record using `into_record()`. Note that tensors can be cloned, but it won't
actually copy any data; it will simply create another reference to the same data.
Then, a `Recorder` instance can be used to serialize any record. The `Recorder` has the
`PrecisionSettings` type as an associated type, so any record will be serialized using the settings