diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
index 05d131a5..3a5bb0ce 100644
--- a/LLama.Unittest/TemplateTests.cs
+++ b/LLama.Unittest/TemplateTests.cs
@@ -199,15 +199,21 @@ public sealed class TemplateTests
templater.Add("user", "4b");
templater.Add("assistant", "5");
- Assert.Equal(("user", "4a"), templater[3]);
- Assert.Equal(("assistant", "5"), templater[5]);
+ Assert.Equal("user", templater[3].Role);
+ Assert.Equal("4a", templater[3].Content);
+
+ Assert.Equal("assistant", templater[5].Role);
+ Assert.Equal("5", templater[5].Content);
Assert.Equal(6, templater.Count);
templater.RemoveAt(3);
Assert.Equal(5, templater.Count);
- Assert.Equal(("user", "4b"), templater[3]);
- Assert.Equal(("assistant", "5"), templater[4]);
+ Assert.Equal("user", templater[3].Role);
+ Assert.Equal("4b", templater[3].Content);
+
+ Assert.Equal("assistant", templater[4].Role);
+ Assert.Equal("5", templater[4].Content);
}
[Fact]
@@ -226,7 +232,8 @@ public sealed class TemplateTests
templater.RemoveAt(5);
Assert.Equal(5, templater.Count);
- Assert.Equal(("user", "4b"), templater[4]);
+ Assert.Equal("user", templater[4].Role);
+ Assert.Equal("4b", templater[4].Content);
}
[Fact]
diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs
index f3032adc..c39cd0db 100644
--- a/LLama/LLamaTemplate.cs
+++ b/LLama/LLamaTemplate.cs
@@ -31,7 +31,7 @@ public sealed class LLamaTemplate
///
/// Array of messages. The property indicates how many messages there are
///
- private Message[] _messages = new Message[4];
+ private TextMessage?[] _messages = new TextMessage[4];
///
/// Backing field for
@@ -71,7 +71,7 @@ public sealed class LLamaTemplate
///
///
/// Thrown if index is less than zero or greater than or equal to
- public (string role, string content) this[int index]
+ public TextMessage this[int index]
{
get
{
@@ -80,7 +80,7 @@ public sealed class LLamaTemplate
if (index >= Count)
throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
- return (_messages[index].Role, _messages[index].Content);
+ return _messages[index]!;
}
}
@@ -131,30 +131,45 @@ public sealed class LLamaTemplate
}
#endregion
+ #region modify
///
/// Add a new message to the end of this template
///
///
///
- public void Add(string role, string content)
+ /// This template, for chaining calls.
+ public LLamaTemplate Add(string role, string content)
+ {
+ return Add(new TextMessage(role, content, _roleCache));
+ }
+
+ ///
+ /// Add a new message to the end of this template
+ ///
+ ///
+ /// This template, for chaining calls.
+ public LLamaTemplate Add(TextMessage message)
{
// Expand messages array if necessary
if (Count == _messages.Length)
Array.Resize(ref _messages, _messages.Length * 2);
// Add message
- _messages[Count] = new Message(role, content, _roleCache);
+ _messages[Count] = message;
Count++;
// Mark as dirty to ensure template is recalculated
_dirty = true;
+
+ return this;
}
///
/// Remove a message at the given index
///
///
- public void RemoveAt(int index)
+ /// This template, for chaining calls.
+ public LLamaTemplate RemoveAt(int index)
{
if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
@@ -169,7 +184,10 @@ public sealed class LLamaTemplate
Array.Copy(_messages, index + 1, _messages, index, Count - index);
_messages[Count] = default;
+
+ return this;
}
+ #endregion
///
/// Apply the template to the messages and write it into the output buffer
@@ -192,7 +210,8 @@ public sealed class LLamaTemplate
Array.Resize(ref _nativeChatMessages, _messages.Length);
for (var i = 0; i < Count; i++)
{
- ref var m = ref _messages[i];
+ ref var m = ref _messages[i]!;
+ Debug.Assert(m != null);
totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
// Pin byte arrays in place
@@ -258,17 +277,24 @@ public sealed class LLamaTemplate
}
///
- /// A message that has been added to the template, contains role and content converted into UTF8 bytes.
+ /// A message that has been added to a template
///
- private readonly record struct Message
+ public sealed class TextMessage
{
+ ///
+ /// The "role" string for this message
+ ///
public string Role { get; }
+
+ ///
+ /// The text content of this message
+ ///
public string Content { get; }
- public ReadOnlyMemory RoleBytes { get; }
- public ReadOnlyMemory ContentBytes { get; }
+ internal ReadOnlyMemory RoleBytes { get; }
+ internal ReadOnlyMemory ContentBytes { get; }
- public Message(string role, string content, Dictionary> roleCache)
+ internal TextMessage(string role, string content, IDictionary> roleCache)
{
Role = role;
Content = content;
@@ -297,5 +323,16 @@ public sealed class LLamaTemplate
Debug.Assert(contentArray.Length == encodedContentLength + 1);
ContentBytes = contentArray;
}
+
+ ///
+ /// Deconstruct this message into role and content
+ ///
+ ///
+ ///
+ public void Deconstruct(out string role, out string content)
+ {
+ role = Role;
+ content = Content;
+ }
}
}
\ No newline at end of file