diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs index 05d131a5..3a5bb0ce 100644 --- a/LLama.Unittest/TemplateTests.cs +++ b/LLama.Unittest/TemplateTests.cs @@ -199,15 +199,21 @@ public sealed class TemplateTests templater.Add("user", "4b"); templater.Add("assistant", "5"); - Assert.Equal(("user", "4a"), templater[3]); - Assert.Equal(("assistant", "5"), templater[5]); + Assert.Equal("user", templater[3].Role); + Assert.Equal("4a", templater[3].Content); + + Assert.Equal("assistant", templater[5].Role); + Assert.Equal("5", templater[5].Content); Assert.Equal(6, templater.Count); templater.RemoveAt(3); Assert.Equal(5, templater.Count); - Assert.Equal(("user", "4b"), templater[3]); - Assert.Equal(("assistant", "5"), templater[4]); + Assert.Equal("user", templater[3].Role); + Assert.Equal("4b", templater[3].Content); + + Assert.Equal("assistant", templater[4].Role); + Assert.Equal("5", templater[4].Content); } [Fact] @@ -226,7 +232,8 @@ public sealed class TemplateTests templater.RemoveAt(5); Assert.Equal(5, templater.Count); - Assert.Equal(("user", "4b"), templater[4]); + Assert.Equal("user", templater[4].Role); + Assert.Equal("4b", templater[4].Content); } [Fact] diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs index f3032adc..c39cd0db 100644 --- a/LLama/LLamaTemplate.cs +++ b/LLama/LLamaTemplate.cs @@ -31,7 +31,7 @@ public sealed class LLamaTemplate /// /// Array of messages. The property indicates how many messages there are /// - private Message[] _messages = new Message[4]; + private TextMessage?[] _messages = new TextMessage[4]; /// /// Backing field for @@ -71,7 +71,7 @@ public sealed class LLamaTemplate /// /// /// Thrown if index is less than zero or greater than or equal to - public (string role, string content) this[int index] + public TextMessage this[int index] { get { @@ -80,7 +80,7 @@ public sealed class LLamaTemplate if (index >= Count) throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count"); - return (_messages[index].Role, _messages[index].Content); + return _messages[index]!; } } @@ -131,30 +131,45 @@ public sealed class LLamaTemplate } #endregion + #region modify /// /// Add a new message to the end of this template /// /// /// - public void Add(string role, string content) + /// This template, for chaining calls. + public LLamaTemplate Add(string role, string content) + { + return Add(new TextMessage(role, content, _roleCache)); + } + + /// + /// Add a new message to the end of this template + /// + /// + /// This template, for chaining calls. + public LLamaTemplate Add(TextMessage message) { // Expand messages array if necessary if (Count == _messages.Length) Array.Resize(ref _messages, _messages.Length * 2); // Add message - _messages[Count] = new Message(role, content, _roleCache); + _messages[Count] = message; Count++; // Mark as dirty to ensure template is recalculated _dirty = true; + + return this; } /// /// Remove a message at the given index /// /// - public void RemoveAt(int index) + /// This template, for chaining calls. + public LLamaTemplate RemoveAt(int index) { if (index < 0) throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero"); @@ -169,7 +184,10 @@ public sealed class LLamaTemplate Array.Copy(_messages, index + 1, _messages, index, Count - index); _messages[Count] = default; + + return this; } + #endregion /// /// Apply the template to the messages and write it into the output buffer @@ -192,7 +210,8 @@ public sealed class LLamaTemplate Array.Resize(ref _nativeChatMessages, _messages.Length); for (var i = 0; i < Count; i++) { - ref var m = ref _messages[i]; + ref var m = ref _messages[i]!; + Debug.Assert(m != null); totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length; // Pin byte arrays in place @@ -258,17 +277,24 @@ public sealed class LLamaTemplate } /// - /// A message that has been added to the template, contains role and content converted into UTF8 bytes. + /// A message that has been added to a template /// - private readonly record struct Message + public sealed class TextMessage { + /// + /// The "role" string for this message + /// public string Role { get; } + + /// + /// The text content of this message + /// public string Content { get; } - public ReadOnlyMemory RoleBytes { get; } - public ReadOnlyMemory ContentBytes { get; } + internal ReadOnlyMemory RoleBytes { get; } + internal ReadOnlyMemory ContentBytes { get; } - public Message(string role, string content, Dictionary> roleCache) + internal TextMessage(string role, string content, IDictionary> roleCache) { Role = role; Content = content; @@ -297,5 +323,16 @@ public sealed class LLamaTemplate Debug.Assert(contentArray.Length == encodedContentLength + 1); ContentBytes = contentArray; } + + /// + /// Deconstruct this message into role and content + /// + /// + /// + public void Deconstruct(out string role, out string content) + { + role = Role; + content = Content; + } } } \ No newline at end of file