Add layer norm onnx op support (#1680)

Guillaume Lagrange 2024-04-23 11:19:07 -04:00 committed by GitHub
parent 1718da5210
commit e6b1b7a317
11 changed files with 319 additions and 4 deletions

View File

@ -87,7 +87,7 @@ represent the corresponding Burn Op.
| [InstanceNormalization][79] | ❌ | ✅ |
| [IsInf][80] | ❌ | ❌ |
| [IsNaN][81] | ❌ | ❌ |
| [LayerNormalization][82] | ❌ | ✅ |
| [LayerNormalization][82] | ✅ | ✅ |
| [LeakyRelu][83] | ✅ | ✅ |
| [Less][84] | ❌ | ✅ |
| [LessOrEqual][85] | ❌ | ✅ |

View File

@ -27,6 +27,7 @@ fn main() {
.input("tests/gather/gather.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/layer_norm/layer_norm.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/layer_norm/layer_norm.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm = nn.LayerNorm(4)

    def forward(self, x):
        return self.norm(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "layer_norm.onnx"

    test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4)

    # LayerNormalization only appeared in opset 17
    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=17)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    output = model.forward(test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()
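
A quick way to sanity-check the exported layer_norm.onnx is to run it through onnxruntime and compare against a fresh nn.LayerNorm(4), which carries the same default weights (ones and zeros). This is a minimal sketch, not part of the repository, and assumes the onnxruntime and numpy packages are available:

# Sanity-check sketch (not part of the repository): compare the exported ONNX model
# against PyTorch. nn.LayerNorm(4) initializes weight to ones and bias to zeros, so a
# freshly constructed module matches the exported one.
import numpy as np
import onnxruntime as ort
import torch

session = ort.InferenceSession("layer_norm.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

test_input = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)
(onnx_output,) = session.run(None, {input_name: test_input.numpy()})

with torch.no_grad():
    torch_output = torch.nn.LayerNorm(4)(test_input).numpy()

# Every length-4 row should come out as roughly [-1.3416, -0.4472, 0.4472, 1.3416].
print(np.allclose(onnx_output, torch_output, atol=1e-6))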

View File

@ -36,6 +36,7 @@ include_models!(
    gather,
    gelu,
    global_avr_pool,
    layer_norm,
    leaky_relu,
    linear,
    log_softmax,
@ -600,6 +601,40 @@ mod tests {
        assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
    }

    #[test]
    fn layer_norm() {
        let device = Default::default();
        let model: layer_norm::Model<Backend> = layer_norm::Model::default();

        // Run the model with a simple ramp input (0..24) for easier testing
        let input = Tensor::<Backend, 3>::from_floats(
            [
                [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],
                [
                    [12., 13., 14., 15.],
                    [16., 17., 18., 19.],
                    [20., 21., 22., 23.],
                ],
            ],
            &device,
        );
        let output = model.forward(input);
        let expected = Data::from([
            [
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
            ],
            [
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
            ],
        ]);

        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn leaky_relu() {
        // Initialize the model without weights (because the exported file does not contain them)
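
The expected values in the layer_norm test above can be recomputed by hand: every length-4 row of the input is a ramp with step 1, so after subtracting its mean each row reduces to [-1.5, -0.5, 0.5, 1.5] and normalizes to the same four values. A minimal recomputation in Python (illustrative only, not part of the test suite):

# Recompute the LayerNorm output for one row, e.g. [0., 1., 2., 3.]:
# mean = 1.5, biased variance = 1.25, so (x - 1.5) / sqrt(1.25 + eps).
import numpy as np

row = np.array([0.0, 1.0, 2.0, 3.0])
eps = 1e-5  # matches the default epsilon used by nn.LayerNorm
normalized = (row - row.mean()) / np.sqrt(row.var() + eps)  # var() is the biased estimate
print(np.round(normalized, 4))  # [-1.3416 -0.4472  0.4472  1.3416]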

View File

@ -1,3 +1,4 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
@ -87,6 +88,7 @@ pub enum Node<PS: PrecisionSettings> {
    Dropout(DropoutNode),
    Gather(GatherNode),
    GlobalAvgPool(GlobalAvgPoolNode),
    LayerNorm(LayerNormNode<PS>),
    Linear(LinearNode<PS>),
    Matmul(MatmulNode),
    MaxPool2d(MaxPool2dNode),
@ -112,6 +114,7 @@ macro_rules! match_all {
            Node::Dropout(node) => $func(node),
            Node::Gather(node) => $func(node),
            Node::GlobalAvgPool(node) => $func(node),
            Node::LayerNorm(node) => $func(node),
            Node::Linear(node) => $func(node),
            Node::Matmul(node) => $func(node),
            Node::MaxPool2d(node) => $func(node),
@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Dropout(_) => "dropout",
            Node::Gather(_) => "gather",
            Node::GlobalAvgPool(_) => "global_avg_pool",
            Node::LayerNorm(_) => "layer_norm",
            Node::Linear(_) => "linear",
            Node::Matmul(_) => "matmul",
            Node::MaxPool2d(_) => "max_pool2d",

View File

@ -0,0 +1,177 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
    module::{ConstantRecord, Param, ParamId},
    nn::{LayerNormConfig, LayerNormRecord},
    record::{PrecisionSettings, Record},
    tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Debug, Clone)]
pub struct LayerNormNode<PS: PrecisionSettings> {
    pub field: OtherType,
    pub input: TensorType,
    pub output: TensorType,
    pub gamma: DataSerialize<PS::FloatElem>,        // Scale
    pub beta: Option<DataSerialize<PS::FloatElem>>, // Bias (B)
    pub config: LayerNormConfig,
    pub full_precision: bool,
}

impl<PS: PrecisionSettings> LayerNormNode<PS> {
    pub fn new<S: AsRef<str>>(
        name: S,
        input: TensorType,
        output: TensorType,
        gamma: DataSerialize<PS::FloatElem>,
        beta: Option<DataSerialize<PS::FloatElem>>,
        config: LayerNormConfig,
        full_precision: bool,
    ) -> Self {
        Self {
            field: OtherType::new(
                name,
                quote! {
                    LayerNorm<B>
                },
            ),
            input,
            output,
            gamma,
            beta,
            config,
            full_precision,
        }
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for LayerNormNode<PS> {
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn field_type(&self) -> Option<Type> {
        Some(Type::Other(self.field.clone()))
    }

    fn field_init(&self) -> Option<TokenStream> {
        let name = &self.field.name;
        let num_features = self.config.d_model.to_tokens();
        let epsilon = self.config.epsilon;

        let tokens = quote! {
            let #name = LayerNormConfig::new(#num_features)
                .with_epsilon(#epsilon)
                .init(device);
        };

        Some(tokens)
    }

    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let device = Default::default();
        let record = LayerNormRecord::<SerializationBackend> {
            gamma: Param::initialized(
                ParamId::new(),
                Tensor::from_data(self.gamma.clone().convert(), &device),
            ),
            beta: Param::initialized(
                ParamId::new(),
                if let Some(beta) = self.beta.clone() {
                    Tensor::from_data(beta.convert(), &device)
                } else {
                    Tensor::zeros([self.config.d_model], &device)
                },
            ),
            epsilon: ConstantRecord::new(),
        };

        let item = Record::into_item::<PS>(record);
        item.serialize(serializer)
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let field = &self.field.name;

        // TODO: handle self.full_precision
        quote! {
            let #output = self.#field.forward(#input);
        }
    }

    fn register_imports(&self, imports: &mut BurnImports) {
        imports.register("burn::nn::LayerNorm");
        imports.register("burn::nn::LayerNormConfig");
    }

    fn into_node(self) -> Node<PS> {
        Node::LayerNorm(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
    use burn::{record::FullPrecisionSettings, tensor::Data};

    #[test]
    fn test_codegen() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(LayerNormNode::new(
            "norm",
            TensorType::new_float("input", 4),
            TensorType::new_float("output", 4),
            Data::from([2.]).serialize(),
            Some(Data::from([2.]).serialize()),
            LayerNormConfig::new(128),
            true, // full_precision isn't taken into account
        ));

        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };
            use burn::nn::LayerNorm;
            use burn::nn::LayerNormConfig;

            #[derive(Module, Debug)]
            pub struct Model <B: Backend> {
                norm: LayerNorm<B>,
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model <B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    let norm = LayerNormConfig::new(128)
                        .with_epsilon(0.00001f64)
                        .init(device);

                    Self {
                        norm,
                        phantom: core::marker::PhantomData,
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
                    let output = self.norm.forward(input);

                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}

View File

@ -12,6 +12,7 @@ pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;

View File

@ -33,6 +33,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::GatherElements => same_as_input(node),
        NodeType::GlobalAveragePool => same_as_input(node),
        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
        NodeType::LayerNormalization => same_as_input(node),
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Log => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),

View File

@ -1,8 +1,8 @@
use burn::nn::{
    conv::Conv1dConfig,
    conv::{Conv2dConfig, ConvTranspose2dConfig},
    conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
    pool::{AvgPool2dConfig, MaxPool2dConfig},
    BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d,
    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
    PaddingConfig2d,
};
use super::ir::{ArgType, AttributeValue, Data, Node};
@ -465,6 +465,42 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig {
        .with_momentum(momentum as f64)
}

/// Create a LayerNormConfig from the attributes of the node
pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
    // Extract the shape of the weight tensor
    let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty {
        tensor_type
    } else {
        panic!("LayerNorm: weight tensor must be present");
    };

    let num_features: usize = tensor_type.shape.clone().unwrap()[0];

    // When `stash_type` is `1` (default), perform operations in 32-bit float and
    // cast the results back to the original dtype
    let mut stash_type = 1;
    let mut axis = -1;
    let mut epsilon = 1e-5;

    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axis" => axis = value.clone().into_i64(),
            "epsilon" => epsilon = value.clone().into_f32(),
            "stash_type" => stash_type = value.clone().into_i64(),
            _ => {}
        }
    }

    if axis != -1 && axis != tensor_type.dim as i64 - 1 {
        panic!("LayerNorm: normalization is only supported on the last axis right now")
    }

    (
        LayerNormConfig::new(num_features).with_epsilon(epsilon as f64),
        stash_type == 1,
    )
}
/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling.
///
/// # Arguments
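
For context, layer_norm_config above reads the standard ONNX LayerNormalization signature: inputs X, Scale and optional B, plus the axis, epsilon and stash_type attributes. A sketch of such a node built with the onnx Python helpers, using the spec's default attribute values (illustrative only, not taken from this repository):

# Illustrative LayerNormalization node. The importer takes num_features from the
# Scale input's shape, maps epsilon onto LayerNormConfig, only accepts last-axis
# normalization, and reports stash_type == 1 as "full precision".
from onnx import helper

node = helper.make_node(
    "LayerNormalization",
    inputs=["X", "Scale", "B"],  # B (bias/beta) is optional
    outputs=["Y"],
    axis=-1,       # default: normalize over the last axis
    epsilon=1e-5,  # default
    stash_type=1,  # default: compute mean/variance in float32
)
print(node)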

View File

@ -25,6 +25,7 @@ use crate::{
            dropout::DropoutNode,
            gather::GatherNode,
            global_avg_pool::GlobalAvgPoolNode,
            layer_norm::LayerNormNode,
            linear::LinearNode,
            mask_where::WhereNode,
            matmul::MatmulNode,
@ -239,6 +240,9 @@ impl OnnxGraph {
                NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
                NodeType::Neg => graph.register(Self::neg_conversion(node)),
                NodeType::Not => graph.register(Self::not_conversion(node)),
                NodeType::LayerNormalization => {
                    graph.register(Self::layer_norm_conversion::<PS>(node))
                }
                NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
                NodeType::BatchNormalization => {
                    graph.register(Self::batch_norm_conversion::<PS>(node))
@ -635,6 +639,21 @@ impl OnnxGraph {
        )
    }

    fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode<PS> {
        let (config, full_precision) = layer_norm_config(&node);
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();

        // Scale tensor (aka gamma)
        let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
        // Bias (B) optional tensor
        let beta = extract_data_serialize::<PS::FloatElem>(2, &node);

        let name = &node.name;

        LayerNormNode::new(name, input, output, gamma, beta, config, full_precision)
    }

    fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode<PS> {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();