Changes based on review feedback:
- Returning template for chaining method calls - Returning a `TextMessage` object instead of a tuple
This commit is contained in:
parent
a0335f67a4
commit
4332ab3813
|
@ -199,15 +199,21 @@ public sealed class TemplateTests
|
||||||
templater.Add("user", "4b");
|
templater.Add("user", "4b");
|
||||||
templater.Add("assistant", "5");
|
templater.Add("assistant", "5");
|
||||||
|
|
||||||
Assert.Equal(("user", "4a"), templater[3]);
|
Assert.Equal("user", templater[3].Role);
|
||||||
Assert.Equal(("assistant", "5"), templater[5]);
|
Assert.Equal("4a", templater[3].Content);
|
||||||
|
|
||||||
|
Assert.Equal("assistant", templater[5].Role);
|
||||||
|
Assert.Equal("5", templater[5].Content);
|
||||||
|
|
||||||
Assert.Equal(6, templater.Count);
|
Assert.Equal(6, templater.Count);
|
||||||
templater.RemoveAt(3);
|
templater.RemoveAt(3);
|
||||||
Assert.Equal(5, templater.Count);
|
Assert.Equal(5, templater.Count);
|
||||||
|
|
||||||
Assert.Equal(("user", "4b"), templater[3]);
|
Assert.Equal("user", templater[3].Role);
|
||||||
Assert.Equal(("assistant", "5"), templater[4]);
|
Assert.Equal("4b", templater[3].Content);
|
||||||
|
|
||||||
|
Assert.Equal("assistant", templater[4].Role);
|
||||||
|
Assert.Equal("5", templater[4].Content);
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
|
@ -226,7 +232,8 @@ public sealed class TemplateTests
|
||||||
templater.RemoveAt(5);
|
templater.RemoveAt(5);
|
||||||
Assert.Equal(5, templater.Count);
|
Assert.Equal(5, templater.Count);
|
||||||
|
|
||||||
Assert.Equal(("user", "4b"), templater[4]);
|
Assert.Equal("user", templater[4].Role);
|
||||||
|
Assert.Equal("4b", templater[4].Content);
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
|
|
|
@ -31,7 +31,7 @@ public sealed class LLamaTemplate
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Array of messages. The <see cref="Count"/> property indicates how many messages there are
|
/// Array of messages. The <see cref="Count"/> property indicates how many messages there are
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private Message[] _messages = new Message[4];
|
private TextMessage?[] _messages = new TextMessage[4];
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Backing field for <see cref="AddAssistant"/>
|
/// Backing field for <see cref="AddAssistant"/>
|
||||||
|
@ -71,7 +71,7 @@ public sealed class LLamaTemplate
|
||||||
/// <param name="index"></param>
|
/// <param name="index"></param>
|
||||||
/// <returns></returns>
|
/// <returns></returns>
|
||||||
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
|
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
|
||||||
public (string role, string content) this[int index]
|
public TextMessage this[int index]
|
||||||
{
|
{
|
||||||
get
|
get
|
||||||
{
|
{
|
||||||
|
@ -80,7 +80,7 @@ public sealed class LLamaTemplate
|
||||||
if (index >= Count)
|
if (index >= Count)
|
||||||
throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
|
throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
|
||||||
|
|
||||||
return (_messages[index].Role, _messages[index].Content);
|
return _messages[index]!;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,30 +131,45 @@ public sealed class LLamaTemplate
|
||||||
}
|
}
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
#region modify
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Add a new message to the end of this template
|
/// Add a new message to the end of this template
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="role"></param>
|
/// <param name="role"></param>
|
||||||
/// <param name="content"></param>
|
/// <param name="content"></param>
|
||||||
public void Add(string role, string content)
|
/// <returns>This template, for chaining calls.</returns>
|
||||||
|
public LLamaTemplate Add(string role, string content)
|
||||||
|
{
|
||||||
|
return Add(new TextMessage(role, content, _roleCache));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Add a new message to the end of this template
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="message"></param>
|
||||||
|
/// <returns>This template, for chaining calls.</returns>
|
||||||
|
public LLamaTemplate Add(TextMessage message)
|
||||||
{
|
{
|
||||||
// Expand messages array if necessary
|
// Expand messages array if necessary
|
||||||
if (Count == _messages.Length)
|
if (Count == _messages.Length)
|
||||||
Array.Resize(ref _messages, _messages.Length * 2);
|
Array.Resize(ref _messages, _messages.Length * 2);
|
||||||
|
|
||||||
// Add message
|
// Add message
|
||||||
_messages[Count] = new Message(role, content, _roleCache);
|
_messages[Count] = message;
|
||||||
Count++;
|
Count++;
|
||||||
|
|
||||||
// Mark as dirty to ensure template is recalculated
|
// Mark as dirty to ensure template is recalculated
|
||||||
_dirty = true;
|
_dirty = true;
|
||||||
|
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Remove a message at the given index
|
/// Remove a message at the given index
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="index"></param>
|
/// <param name="index"></param>
|
||||||
public void RemoveAt(int index)
|
/// <returns>This template, for chaining calls.</returns>
|
||||||
|
public LLamaTemplate RemoveAt(int index)
|
||||||
{
|
{
|
||||||
if (index < 0)
|
if (index < 0)
|
||||||
throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
|
throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
|
||||||
|
@ -169,7 +184,10 @@ public sealed class LLamaTemplate
|
||||||
Array.Copy(_messages, index + 1, _messages, index, Count - index);
|
Array.Copy(_messages, index + 1, _messages, index, Count - index);
|
||||||
|
|
||||||
_messages[Count] = default;
|
_messages[Count] = default;
|
||||||
|
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
#endregion
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Apply the template to the messages and write it into the output buffer
|
/// Apply the template to the messages and write it into the output buffer
|
||||||
|
@ -192,7 +210,8 @@ public sealed class LLamaTemplate
|
||||||
Array.Resize(ref _nativeChatMessages, _messages.Length);
|
Array.Resize(ref _nativeChatMessages, _messages.Length);
|
||||||
for (var i = 0; i < Count; i++)
|
for (var i = 0; i < Count; i++)
|
||||||
{
|
{
|
||||||
ref var m = ref _messages[i];
|
ref var m = ref _messages[i]!;
|
||||||
|
Debug.Assert(m != null);
|
||||||
totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
|
totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
|
||||||
|
|
||||||
// Pin byte arrays in place
|
// Pin byte arrays in place
|
||||||
|
@ -258,17 +277,24 @@ public sealed class LLamaTemplate
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// A message that has been added to the template, contains role and content converted into UTF8 bytes.
|
/// A message that has been added to a template
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private readonly record struct Message
|
public sealed class TextMessage
|
||||||
{
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// The "role" string for this message
|
||||||
|
/// </summary>
|
||||||
public string Role { get; }
|
public string Role { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The text content of this message
|
||||||
|
/// </summary>
|
||||||
public string Content { get; }
|
public string Content { get; }
|
||||||
|
|
||||||
public ReadOnlyMemory<byte> RoleBytes { get; }
|
internal ReadOnlyMemory<byte> RoleBytes { get; }
|
||||||
public ReadOnlyMemory<byte> ContentBytes { get; }
|
internal ReadOnlyMemory<byte> ContentBytes { get; }
|
||||||
|
|
||||||
public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache)
|
internal TextMessage(string role, string content, IDictionary<string, ReadOnlyMemory<byte>> roleCache)
|
||||||
{
|
{
|
||||||
Role = role;
|
Role = role;
|
||||||
Content = content;
|
Content = content;
|
||||||
|
@ -297,5 +323,16 @@ public sealed class LLamaTemplate
|
||||||
Debug.Assert(contentArray.Length == encodedContentLength + 1);
|
Debug.Assert(contentArray.Length == encodedContentLength + 1);
|
||||||
ContentBytes = contentArray;
|
ContentBytes = contentArray;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Deconstruct this message into role and content
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="role"></param>
|
||||||
|
/// <param name="content"></param>
|
||||||
|
public void Deconstruct(out string role, out string content)
|
||||||
|
{
|
||||||
|
role = Role;
|
||||||
|
content = Content;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue