Changes based on review feedback:

- Returning template for chaining method calls
 - Returning a `TextMessage` object instead of a tuple
This commit is contained in:
Martin Evans 2024-05-06 14:04:47 +01:00
parent a0335f67a4
commit 4332ab3813
2 changed files with 61 additions and 17 deletions

View File

@ -199,15 +199,21 @@ public sealed class TemplateTests
templater.Add("user", "4b"); templater.Add("user", "4b");
templater.Add("assistant", "5"); templater.Add("assistant", "5");
Assert.Equal(("user", "4a"), templater[3]); Assert.Equal("user", templater[3].Role);
Assert.Equal(("assistant", "5"), templater[5]); Assert.Equal("4a", templater[3].Content);
Assert.Equal("assistant", templater[5].Role);
Assert.Equal("5", templater[5].Content);
Assert.Equal(6, templater.Count); Assert.Equal(6, templater.Count);
templater.RemoveAt(3); templater.RemoveAt(3);
Assert.Equal(5, templater.Count); Assert.Equal(5, templater.Count);
Assert.Equal(("user", "4b"), templater[3]); Assert.Equal("user", templater[3].Role);
Assert.Equal(("assistant", "5"), templater[4]); Assert.Equal("4b", templater[3].Content);
Assert.Equal("assistant", templater[4].Role);
Assert.Equal("5", templater[4].Content);
} }
[Fact] [Fact]
@ -226,7 +232,8 @@ public sealed class TemplateTests
templater.RemoveAt(5); templater.RemoveAt(5);
Assert.Equal(5, templater.Count); Assert.Equal(5, templater.Count);
Assert.Equal(("user", "4b"), templater[4]); Assert.Equal("user", templater[4].Role);
Assert.Equal("4b", templater[4].Content);
} }
[Fact] [Fact]

View File

@ -31,7 +31,7 @@ public sealed class LLamaTemplate
/// <summary> /// <summary>
/// Array of messages. The <see cref="Count"/> property indicates how many messages there are /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
/// </summary> /// </summary>
private Message[] _messages = new Message[4]; private TextMessage?[] _messages = new TextMessage[4];
/// <summary> /// <summary>
/// Backing field for <see cref="AddAssistant"/> /// Backing field for <see cref="AddAssistant"/>
@ -71,7 +71,7 @@ public sealed class LLamaTemplate
/// <param name="index"></param> /// <param name="index"></param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception> /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
public (string role, string content) this[int index] public TextMessage this[int index]
{ {
get get
{ {
@ -80,7 +80,7 @@ public sealed class LLamaTemplate
if (index >= Count) if (index >= Count)
throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count"); throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
return (_messages[index].Role, _messages[index].Content); return _messages[index]!;
} }
} }
@ -131,30 +131,45 @@ public sealed class LLamaTemplate
} }
#endregion #endregion
#region modify
/// <summary> /// <summary>
/// Add a new message to the end of this template /// Add a new message to the end of this template
/// </summary> /// </summary>
/// <param name="role"></param> /// <param name="role"></param>
/// <param name="content"></param> /// <param name="content"></param>
public void Add(string role, string content) /// <returns>This template, for chaining calls.</returns>
public LLamaTemplate Add(string role, string content)
{
return Add(new TextMessage(role, content, _roleCache));
}
/// <summary>
/// Add a new message to the end of this template
/// </summary>
/// <param name="message"></param>
/// <returns>This template, for chaining calls.</returns>
public LLamaTemplate Add(TextMessage message)
{ {
// Expand messages array if necessary // Expand messages array if necessary
if (Count == _messages.Length) if (Count == _messages.Length)
Array.Resize(ref _messages, _messages.Length * 2); Array.Resize(ref _messages, _messages.Length * 2);
// Add message // Add message
_messages[Count] = new Message(role, content, _roleCache); _messages[Count] = message;
Count++; Count++;
// Mark as dirty to ensure template is recalculated // Mark as dirty to ensure template is recalculated
_dirty = true; _dirty = true;
return this;
} }
/// <summary> /// <summary>
/// Remove a message at the given index /// Remove a message at the given index
/// </summary> /// </summary>
/// <param name="index"></param> /// <param name="index"></param>
public void RemoveAt(int index) /// <returns>This template, for chaining calls.</returns>
public LLamaTemplate RemoveAt(int index)
{ {
if (index < 0) if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero"); throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
@ -169,7 +184,10 @@ public sealed class LLamaTemplate
Array.Copy(_messages, index + 1, _messages, index, Count - index); Array.Copy(_messages, index + 1, _messages, index, Count - index);
_messages[Count] = default; _messages[Count] = default;
return this;
} }
#endregion
/// <summary> /// <summary>
/// Apply the template to the messages and write it into the output buffer /// Apply the template to the messages and write it into the output buffer
@ -192,7 +210,8 @@ public sealed class LLamaTemplate
Array.Resize(ref _nativeChatMessages, _messages.Length); Array.Resize(ref _nativeChatMessages, _messages.Length);
for (var i = 0; i < Count; i++) for (var i = 0; i < Count; i++)
{ {
ref var m = ref _messages[i]; ref var m = ref _messages[i]!;
Debug.Assert(m != null);
totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length; totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
// Pin byte arrays in place // Pin byte arrays in place
@ -258,17 +277,24 @@ public sealed class LLamaTemplate
} }
/// <summary> /// <summary>
/// A message that has been added to the template, contains role and content converted into UTF8 bytes. /// A message that has been added to a template
/// </summary> /// </summary>
private readonly record struct Message public sealed class TextMessage
{ {
/// <summary>
/// The "role" string for this message
/// </summary>
public string Role { get; } public string Role { get; }
/// <summary>
/// The text content of this message
/// </summary>
public string Content { get; } public string Content { get; }
public ReadOnlyMemory<byte> RoleBytes { get; } internal ReadOnlyMemory<byte> RoleBytes { get; }
public ReadOnlyMemory<byte> ContentBytes { get; } internal ReadOnlyMemory<byte> ContentBytes { get; }
public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache) internal TextMessage(string role, string content, IDictionary<string, ReadOnlyMemory<byte>> roleCache)
{ {
Role = role; Role = role;
Content = content; Content = content;
@ -297,5 +323,16 @@ public sealed class LLamaTemplate
Debug.Assert(contentArray.Length == encodedContentLength + 1); Debug.Assert(contentArray.Length == encodedContentLength + 1);
ContentBytes = contentArray; ContentBytes = contentArray;
} }
/// <summary>
/// Deconstruct this message into role and content
/// </summary>
/// <param name="role"></param>
/// <param name="content"></param>
public void Deconstruct(out string role, out string content)
{
role = Role;
content = Content;
}
} }
} }