LLamaSharp/LLama.Unittest/TemplateTests.cs

252 lines
8.2 KiB
C#

using System.Text;
using LLama.Common;
using LLama.Native;
namespace LLama.Unittest;
/// <summary>
/// Tests for <see cref="LLamaTemplate"/>: rendering a list of messages to text,
/// reading messages by index and removing them.
/// </summary>
public sealed class TemplateTests
    : IDisposable
{
    /// <summary>
    /// Text expected when <see cref="EightMessages"/> is rendered with the model's
    /// template and no trailing assistant header.
    /// </summary>
    private const string EightMessagesRendered =
        "<|im_start|>assistant\nhello<|im_end|>\n" +
        "<|im_start|>user\nworld<|im_end|>\n" +
        "<|im_start|>assistant\n111<|im_end|>\n" +
        "<|im_start|>user\naaa<|im_end|>\n" +
        "<|im_start|>assistant\n222<|im_end|>\n" +
        "<|im_start|>user\nbbb<|im_end|>\n" +
        "<|im_start|>assistant\n333<|im_end|>\n" +
        "<|im_start|>user\nccc<|im_end|>\n";

    /// <summary>Conversation shared by the two model-template rendering tests.</summary>
    private static readonly (string Role, string Content)[] EightMessages =
    {
        ("assistant", "hello"),
        ("user", "world"),
        ("assistant", "111"),
        ("user", "aaa"),
        ("assistant", "222"),
        ("user", "bbb"),
        ("assistant", "333"),
        ("user", "ccc"),
    };

    /// <summary>Conversation shared by the removal tests; note the two consecutive user messages.</summary>
    private static readonly (string Role, string Content)[] SixMessages =
    {
        ("assistant", "1"),
        ("user", "2"),
        ("assistant", "3"),
        ("user", "4a"),
        ("user", "4b"),
        ("assistant", "5"),
    };

    private readonly LLamaWeights _model;

    /// <summary>
    /// Loads the model weights that the model-backed templates in these tests are built from.
    /// </summary>
    public TemplateTests()
    {
        var modelParams = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };

        _model = LLamaWeights.LoadFromFile(modelParams);
    }

    /// <summary>Releases the model weights loaded by the constructor.</summary>
    public void Dispose() => _model.Dispose();

    /// <summary>
    /// Renders eight messages with the model's own template and compares the text.
    /// </summary>
    [Fact]
    public void BasicTemplate()
    {
        var template = new LLamaTemplate(_model);
        AddEachCheckingCount(template, EightMessages);

        var rendered = RenderToString(template, expectedCount: 8);

        const string expected = EightMessagesRendered;
        Assert.Equal(expected, rendered);
    }

    /// <summary>
    /// Renders four messages with a template selected by name ("gemma") instead of a model.
    /// </summary>
    [Fact]
    public void CustomTemplate()
    {
        var template = new LLamaTemplate("gemma");
        AddEachCheckingCount(
            template,
            new[]
            {
                (Role: "assistant", Content: "hello"),
                (Role: "user", Content: "world"),
                (Role: "assistant", Content: "111"),
                (Role: "user", Content: "aaa"),
            });

        var rendered = RenderToString(template, expectedCount: 4);

        const string expected =
            "<start_of_turn>model\nhello<end_of_turn>\n" +
            "<start_of_turn>user\nworld<end_of_turn>\n" +
            "<start_of_turn>model\n111<end_of_turn>\n" +
            "<start_of_turn>user\naaa<end_of_turn>\n";
        Assert.Equal(expected, rendered);
    }

    /// <summary>
    /// Same as <see cref="BasicTemplate"/> but with <c>AddAssistant</c> set, which is
    /// expected to append an opening assistant header to the rendered text.
    /// </summary>
    [Fact]
    public void BasicTemplateWithAddAssistant()
    {
        var template = new LLamaTemplate(_model)
        {
            AddAssistant = true,
        };
        AddEachCheckingCount(template, EightMessages);

        var rendered = RenderToString(template, expectedCount: 8);

        const string expected = EightMessagesRendered + "<|im_start|>assistant\n";
        Assert.Equal(expected, rendered);
    }

    /// <summary>
    /// Reading an index outside the stored messages must throw
    /// <see cref="ArgumentOutOfRangeException"/>.
    /// </summary>
    [Fact]
    public void GetOutOfRangeThrows()
    {
        var template = new LLamaTemplate(_model);

        // Nothing has been added yet, so index 0 is already out of range
        Assert.Throws<ArgumentOutOfRangeException>(() => template[0]);

        template.Add("assistant", "1");
        template.Add("user", "2");

        // With two messages stored, one below and one past the valid indices
        foreach (var badIndex in new[] { -1, 2 })
        {
            Assert.Throws<ArgumentOutOfRangeException>(() => template[badIndex]);
        }
    }

    /// <summary>
    /// Removing a message from the middle shifts the later messages down by one.
    /// </summary>
    [Fact]
    public void RemoveMid()
    {
        var template = new LLamaTemplate(_model);
        AddEach(template, SixMessages);

        AssertMessageAt(template, 3, "user", "4a");
        AssertMessageAt(template, 5, "assistant", "5");
        Assert.Equal(6, template.Count);

        template.RemoveAt(3);

        Assert.Equal(5, template.Count);
        AssertMessageAt(template, 3, "user", "4b");
        AssertMessageAt(template, 4, "assistant", "5");
    }

    /// <summary>
    /// Removing the final message leaves the previous one as the new last entry.
    /// </summary>
    [Fact]
    public void RemoveLast()
    {
        var template = new LLamaTemplate(_model);
        AddEach(template, SixMessages);
        Assert.Equal(6, template.Count);

        template.RemoveAt(5);

        Assert.Equal(5, template.Count);
        AssertMessageAt(template, 4, "user", "4b");
    }

    /// <summary>
    /// Removing at an index outside the stored messages must throw
    /// <see cref="ArgumentOutOfRangeException"/>.
    /// </summary>
    [Fact]
    public void RemoveOutOfRange()
    {
        var template = new LLamaTemplate(_model);

        // Nothing has been added yet, so index 0 is already out of range
        Assert.Throws<ArgumentOutOfRangeException>(() => template.RemoveAt(0));

        template.Add("assistant", "1");
        template.Add("user", "2");

        // With two messages stored, one below and one past the valid indices
        foreach (var badIndex in new[] { -1, 2 })
        {
            Assert.Throws<ArgumentOutOfRangeException>(() => template.RemoveAt(badIndex));
        }
    }

    /// <summary>Adds every message in order without any intermediate checks.</summary>
    private static void AddEach(LLamaTemplate templater, (string Role, string Content)[] messages)
    {
        foreach (var (role, content) in messages)
        {
            templater.Add(role, content);
        }
    }

    /// <summary>
    /// Checks the template starts empty, then adds every message in order and checks
    /// that <c>Count</c> grows by exactly one after each addition.
    /// </summary>
    private static void AddEachCheckingCount(LLamaTemplate templater, (string Role, string Content)[] messages)
    {
        Assert.Equal(0, templater.Count);

        for (var i = 0; i < messages.Length; i++)
        {
            var (role, content) = messages[i];
            templater.Add(role, content);
            Assert.Equal(i + 1, templater.Count);
        }
    }

    /// <summary>
    /// Applies the template in two passes and decodes the output as UTF-8, checking
    /// after each pass that the message count is still <paramref name="expectedCount"/>.
    /// </summary>
    /// <param name="templater">Template holding the messages to render.</param>
    /// <param name="expectedCount">Number of messages the template must still report after each pass.</param>
    /// <returns>The rendered text.</returns>
    private static string RenderToString(LLamaTemplate templater, int expectedCount)
    {
        // First pass: an empty destination is used only to discover the required length
        var required = templater.Apply(Array.Empty<byte>());
        var buffer = new byte[required];
        Assert.Equal(expectedCount, templater.Count);

        // Second pass: fill the correctly sized buffer with the contents
        var written = templater.Apply(buffer);
        Assert.Equal(expectedCount, templater.Count);

        return Encoding.UTF8.GetString(buffer.AsSpan(0, written));
    }

    /// <summary>Checks the role and content of the message stored at <paramref name="index"/>.</summary>
    private static void AssertMessageAt(LLamaTemplate templater, int index, string role, string content)
    {
        Assert.Equal(role, templater[index].Role);
        Assert.Equal(content, templater[index].Content);
    }
}