mirror of https://github.com/tracel-ai/burn.git
Add layer norm onnx op support (#1680)
This commit is contained in:
parent
1718da5210
commit
e6b1b7a317
|
@ -87,7 +87,7 @@ represent the corresponding Burn Op.
|
|||
| [InstanceNormalization][79] | ❌ | ✅ |
|
||||
| [IsInf][80] | ❌ | ❌ |
|
||||
| [IsNaN][81] | ❌ | ❌ |
|
||||
| [LayerNormalization][82] | ❌ | ✅ |
|
||||
| [LayerNormalization][82] | ✅ | ✅ |
|
||||
| [LeakyRelu][83] | ✅ | ✅ |
|
||||
| [Less][84] | ❌ | ✅ |
|
||||
| [LessOrEqual][85] | ❌ | ✅ |
|
||||
|
|
|
@ -27,6 +27,7 @@ fn main() {
|
|||
.input("tests/gather/gather.onnx")
|
||||
.input("tests/gelu/gelu.onnx")
|
||||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
.input("tests/layer_norm/layer_norm.onnx")
|
||||
.input("tests/linear/linear.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/log/log.onnx")
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/layer_norm/layer_norm.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.norm = nn.LayerNorm(4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "layer_norm.onnx"
|
||||
test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4)
|
||||
# LayerNormalization only appeared in opset 17
|
||||
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=17)
|
||||
|
||||
print(f"Finished exporting model to {onnx_name}")
|
||||
|
||||
# Output some test data for use in the test
|
||||
print(f"Test input data: {test_input}")
|
||||
output = model.forward(test_input)
|
||||
print(f"Test output data: {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -36,6 +36,7 @@ include_models!(
|
|||
gather,
|
||||
gelu,
|
||||
global_avr_pool,
|
||||
layer_norm,
|
||||
leaky_relu,
|
||||
linear,
|
||||
log_softmax,
|
||||
|
@ -600,6 +601,40 @@ mod tests {
|
|||
assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn layer_norm() {
|
||||
let device = Default::default();
|
||||
let model: layer_norm::Model<Backend> = layer_norm::Model::default();
|
||||
|
||||
// Run the model with ones as input for easier testing
|
||||
let input = Tensor::<Backend, 3>::from_floats(
|
||||
[
|
||||
[[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],
|
||||
[
|
||||
[12., 13., 14., 15.],
|
||||
[16., 17., 18., 19.],
|
||||
[20., 21., 22., 23.],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let output = model.forward(input);
|
||||
let expected = Data::from([
|
||||
[
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
],
|
||||
[
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
[-1.3416, -0.4472, 0.4472, 1.3416],
|
||||
],
|
||||
]);
|
||||
|
||||
output.to_data().assert_approx_eq(&expected, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leaky_relu() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use super::layer_norm::LayerNormNode;
|
||||
use super::mask_where::WhereNode;
|
||||
use super::unsqueeze::UnsqueezeNode;
|
||||
use super::{
|
||||
|
@ -87,6 +88,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
Dropout(DropoutNode),
|
||||
Gather(GatherNode),
|
||||
GlobalAvgPool(GlobalAvgPoolNode),
|
||||
LayerNorm(LayerNormNode<PS>),
|
||||
Linear(LinearNode<PS>),
|
||||
Matmul(MatmulNode),
|
||||
MaxPool2d(MaxPool2dNode),
|
||||
|
@ -112,6 +114,7 @@ macro_rules! match_all {
|
|||
Node::Dropout(node) => $func(node),
|
||||
Node::Gather(node) => $func(node),
|
||||
Node::GlobalAvgPool(node) => $func(node),
|
||||
Node::LayerNorm(node) => $func(node),
|
||||
Node::Linear(node) => $func(node),
|
||||
Node::Matmul(node) => $func(node),
|
||||
Node::MaxPool2d(node) => $func(node),
|
||||
|
@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::Dropout(_) => "dropout",
|
||||
Node::Gather(_) => "gather",
|
||||
Node::GlobalAvgPool(_) => "global_avg_pool",
|
||||
Node::LayerNorm(_) => "layer_norm",
|
||||
Node::Linear(_) => "linear",
|
||||
Node::Matmul(_) => "matmul",
|
||||
Node::MaxPool2d(_) => "max_pool2d",
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
use super::{Node, NodeCodegen, SerializationBackend};
|
||||
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
|
||||
use burn::{
|
||||
module::{ConstantRecord, Param, ParamId},
|
||||
nn::{LayerNormConfig, LayerNormRecord},
|
||||
record::{PrecisionSettings, Record},
|
||||
tensor::{DataSerialize, Tensor},
|
||||
};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LayerNormNode<PS: PrecisionSettings> {
|
||||
pub field: OtherType,
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub gamma: DataSerialize<PS::FloatElem>, // Scale
|
||||
pub beta: Option<DataSerialize<PS::FloatElem>>, // Bias (B)
|
||||
pub config: LayerNormConfig,
|
||||
pub full_precision: bool,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> LayerNormNode<PS> {
|
||||
pub fn new<S: AsRef<str>>(
|
||||
name: S,
|
||||
input: TensorType,
|
||||
output: TensorType,
|
||||
gamma: DataSerialize<PS::FloatElem>,
|
||||
beta: Option<DataSerialize<PS::FloatElem>>,
|
||||
config: LayerNormConfig,
|
||||
full_precision: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
field: OtherType::new(
|
||||
name,
|
||||
quote! {
|
||||
LayerNorm<B>
|
||||
},
|
||||
),
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
config,
|
||||
full_precision,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for LayerNormNode<PS> {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self) -> Option<TokenStream> {
|
||||
let name = &self.field.name;
|
||||
let num_features = self.config.d_model.to_tokens();
|
||||
let epsilon = self.config.epsilon;
|
||||
|
||||
let tokens = quote! {
|
||||
let #name = LayerNormConfig::new(#num_features)
|
||||
.with_epsilon(#epsilon)
|
||||
.init(device);
|
||||
};
|
||||
|
||||
Some(tokens)
|
||||
}
|
||||
|
||||
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let device = Default::default();
|
||||
let record = LayerNormRecord::<SerializationBackend> {
|
||||
gamma: Param::initialized(
|
||||
ParamId::new(),
|
||||
Tensor::from_data(self.gamma.clone().convert(), &device),
|
||||
),
|
||||
beta: Param::initialized(
|
||||
ParamId::new(),
|
||||
if let Some(beta) = self.beta.clone() {
|
||||
Tensor::from_data(beta.convert(), &device)
|
||||
} else {
|
||||
Tensor::zeros([self.config.d_model], &device)
|
||||
},
|
||||
),
|
||||
epsilon: ConstantRecord::new(),
|
||||
};
|
||||
|
||||
let item = Record::into_item::<PS>(record);
|
||||
item.serialize(serializer)
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
let field = &self.field.name;
|
||||
|
||||
// TODO: handle self.full_precision
|
||||
quote! {
|
||||
let #output = self.#field.forward(#input);
|
||||
}
|
||||
}
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::nn::LayerNorm");
|
||||
imports.register("burn::nn::LayerNormConfig");
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::LayerNorm(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
|
||||
use burn::{record::FullPrecisionSettings, tensor::Data};
|
||||
|
||||
#[test]
|
||||
fn test_codegen() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(LayerNormNode::new(
|
||||
"norm",
|
||||
TensorType::new_float("input", 4),
|
||||
TensorType::new_float("output", 4),
|
||||
Data::from([2.]).serialize(),
|
||||
Some(Data::from([2.]).serialize()),
|
||||
LayerNormConfig::new(128),
|
||||
true, // full_precision isn't taken into account
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::LayerNorm;
|
||||
use burn::nn::LayerNormConfig;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend> {
|
||||
norm: LayerNorm<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
let norm = LayerNormConfig::new(128)
|
||||
.with_epsilon(0.00001f64)
|
||||
.init(device);
|
||||
|
||||
Self {
|
||||
norm,
|
||||
phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let output = self.norm.forward(input);
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ pub(crate) mod conv_transpose_2d;
|
|||
pub(crate) mod dropout;
|
||||
pub(crate) mod gather;
|
||||
pub(crate) mod global_avg_pool;
|
||||
pub(crate) mod layer_norm;
|
||||
pub(crate) mod linear;
|
||||
pub(crate) mod mask_where;
|
||||
pub(crate) mod matmul;
|
||||
|
|
|
@ -33,6 +33,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::GatherElements => same_as_input(node),
|
||||
NodeType::GlobalAveragePool => same_as_input(node),
|
||||
NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
|
||||
NodeType::LayerNormalization => same_as_input(node),
|
||||
NodeType::Linear => linear_update_outputs(node),
|
||||
NodeType::Log => same_as_input(node),
|
||||
NodeType::LogSoftmax => same_as_input(node),
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use burn::nn::{
|
||||
conv::Conv1dConfig,
|
||||
conv::{Conv2dConfig, ConvTranspose2dConfig},
|
||||
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
|
||||
pool::{AvgPool2dConfig, MaxPool2dConfig},
|
||||
BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d,
|
||||
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
|
||||
PaddingConfig2d,
|
||||
};
|
||||
|
||||
use super::ir::{ArgType, AttributeValue, Data, Node};
|
||||
|
@ -465,6 +465,42 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig {
|
|||
.with_momentum(momentum as f64)
|
||||
}
|
||||
|
||||
/// Create a LayerNormConfig from the attributes of the node
|
||||
pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
|
||||
// Extract the shape of the weight tensor
|
||||
let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty {
|
||||
tensor_type
|
||||
} else {
|
||||
panic!("LayerNorm: weight tensor must be present");
|
||||
};
|
||||
|
||||
let num_features: usize = tensor_type.shape.clone().unwrap()[0];
|
||||
|
||||
// When `stash_type` is `1` (default), perform operations in 32-bit float and
|
||||
// cast the results back to original dtype
|
||||
let mut stash_type = 1;
|
||||
let mut axis = -1;
|
||||
let mut epsilon = 1e-5;
|
||||
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"axis" => axis = value.clone().into_i64(),
|
||||
"epsilon" => epsilon = value.clone().into_f32(),
|
||||
"stash_type" => stash_type = value.clone().into_i64(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if axis != -1 && axis != tensor_type.dim as i64 - 1 {
|
||||
panic!("LayerNorm: normalization is only supported on the last axis right now")
|
||||
}
|
||||
|
||||
(
|
||||
LayerNormConfig::new(num_features).with_epsilon(epsilon as f64),
|
||||
stash_type == 1,
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -25,6 +25,7 @@ use crate::{
|
|||
dropout::DropoutNode,
|
||||
gather::GatherNode,
|
||||
global_avg_pool::GlobalAvgPoolNode,
|
||||
layer_norm::LayerNormNode,
|
||||
linear::LinearNode,
|
||||
mask_where::WhereNode,
|
||||
matmul::MatmulNode,
|
||||
|
@ -239,6 +240,9 @@ impl OnnxGraph {
|
|||
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
|
||||
NodeType::Neg => graph.register(Self::neg_conversion(node)),
|
||||
NodeType::Not => graph.register(Self::not_conversion(node)),
|
||||
NodeType::LayerNormalization => {
|
||||
graph.register(Self::layer_norm_conversion::<PS>(node))
|
||||
}
|
||||
NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
|
||||
NodeType::BatchNormalization => {
|
||||
graph.register(Self::batch_norm_conversion::<PS>(node))
|
||||
|
@ -635,6 +639,21 @@ impl OnnxGraph {
|
|||
)
|
||||
}
|
||||
|
||||
fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode<PS> {
|
||||
let (config, full_precision) = layer_norm_config(&node);
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
|
||||
// Scale tensor (aka gamma)
|
||||
let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
|
||||
// Bias (B) optional tensor
|
||||
let beta = extract_data_serialize::<PS::FloatElem>(2, &node);
|
||||
|
||||
let name = &node.name;
|
||||
|
||||
LayerNormNode::new(name, input, output, gamma, beta, config, full_precision)
|
||||
}
|
||||
|
||||
fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode<PS> {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
|
|
Loading…
Reference in New Issue