diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index 99486088..1bbeae9c 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -47,14 +47,24 @@ public sealed class BeamTests
 
             for (var i = 0; i < state.Beams.Length; i++)
             {
                 ref var view = ref state.Beams[i];
-                var tokens = context.DeTokenize(view.Tokens.ToArray());
+
+                var decoder = new StreamingTokenDecoder(context);
+                decoder.AddRange(view.Tokens);
+                var tokens = decoder.Read();
+
                 _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
             }
 
             if (state.CommonPrefixLength > 0)
             {
                 var view = state.Beams[0];
-                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
+
+                var decoder = new StreamingTokenDecoder(context);
+                decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
+                var tokens = decoder.Read();
+
+                result.Append(tokens);
+
             }
         }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
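
The test switches from one-shot `context.DeTokenize(...)` calls to `StreamingTokenDecoder`, which buffers incomplete multi-byte characters between reads instead of decoding each token slice in isolation. A minimal sketch of the pattern, assuming an already-downloaded model (the path is hypothetical):

```cs
using LLama;
using LLama.Common;

// Hypothetical model path; any GGUF model will do.
var parameters = new ModelParams("models/example.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

// Round-trip some text through the tokenizer and the streaming decoder.
// Because the decoder holds back incomplete UTF-8 sequences, tokens can be
// fed in piecemeal without corrupting multi-byte characters.
var tokens = context.Tokenize("Hello, world!");

var decoder = new StreamingTokenDecoder(context);
decoder.AddRange(tokens);
var text = decoder.Read();
```
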
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index fd8d4189..74ab0f81 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -36,12 +36,12 @@ namespace LLama.Abstractions
         /// </summary>
         public int TopK { get; set; }
 
-        /// <summary>llama_eval
+        /// <summary>
         /// 1.0 = disabled
         /// </summary>
         public float TopP { get; set; }
 
-        /// <summary>llama_eval
+        /// <summary>
         /// 0.0 = disabled
         /// </summary>
         public float MinP { get; set; }
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index f432a89f..8ee58a7f 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -55,6 +55,9 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }
 
+    /// <summary>
+    /// Finalizer for BatchedExecutor
+    /// </summary>
     ~BatchedExecutor()
     {
         Dispose();
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index 6cf6e312..5b248f90 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -54,6 +54,9 @@ public sealed class Conversation
         _end = end;
     }
 
+    /// <summary>
+    /// Finalizer for Conversation
+    /// </summary>
     ~Conversation()
     {
         Dispose();
@@ -96,7 +99,7 @@ public sealed class Conversation
         AssertNotDisposed();
 
         if (RequiresInference)
-            throw new CannotForkWhileRequiresInference();
+            throw new CannotForkWhileRequiresInferenceException();
 
         // Create a new conversation which references the current position in this one
         var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
@@ -195,13 +198,13 @@ public sealed class Conversation
     /// Directly modify the KV cache of this conversation
     /// </summary>
     /// <param name="modifier"></param>
-    /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="RequiresInference"/> == true</exception>
+    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="RequiresInference"/> == true</exception>
    public void Modify(ModifyKvCache modifier)
    {
        AssertNotDisposed();
 
        if (RequiresInference)
-            throw new CannotModifyWhileRequiresInference();
+            throw new CannotModifyWhileRequiresInferenceException();
 
        // do whatever the modification is
        _end = modifier.Invoke(_end, new KvAccessor(this));
diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs
index 1feb270c..b025202b 100644
--- a/LLama/Batched/Exceptions.cs
+++ b/LLama/Batched/Exceptions.cs
@@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInference
+public class CannotForkWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotForkWhileRequiresInference()
+    internal CannotForkWhileRequiresInferenceException()
         : base("Cannot `Fork()` a conversation while RequiresInference is true")
     {
     }
@@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotModifyWhileRequiresInference
+public class CannotModifyWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotModifyWhileRequiresInference()
+    internal CannotModifyWhileRequiresInferenceException()
         : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
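
Both renamed exceptions guard the same protocol: a `Conversation` holding a queued prompt must be run through its executor before it can be forked or modified. A minimal sketch of the guard, assuming the `BatchedExecutor.Infer()` method used elsewhere in the library (the extension method itself is hypothetical):

```cs
using System.Threading.Tasks;
using LLama.Batched;

public static class ConversationForkExtensions
{
    // Fork a conversation, flushing any pending batch first so that
    // CannotForkWhileRequiresInferenceException cannot be thrown.
    public static async Task<Conversation> SafeForkAsync(
        this Conversation conversation, BatchedExecutor executor)
    {
        // A prompted-but-not-yet-inferred conversation throws from Fork()
        if (conversation.RequiresInference)
            await executor.Infer();

        return conversation.Fork();
    }
}
```
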
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 9338d839..bd36f612 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -155,7 +155,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -238,6 +238,8 @@ namespace LLama
                     }
                 }
             }
+
+            return Task.CompletedTask;
         }
 
         /// <inheritdoc />
diff --git a/LLama/Native/LLamaTokenType.cs b/LLama/Native/LLamaTokenType.cs
index 171e782a..651df04a 100644
--- a/LLama/Native/LLamaTokenType.cs
+++ b/LLama/Native/LLamaTokenType.cs
@@ -1,12 +1,34 @@
 namespace LLama.Native;
 
+/// <summary>
+/// Token Types
+/// </summary>
+/// <remarks>C# equivalent of llama_token_get_type</remarks>
 public enum LLamaTokenType
 {
+    /// <summary>
+    /// No specific type has been set for this token
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNDEFINED = 0,
+
+    /// <summary>
+    /// This is a "normal" token
+    /// </summary>
     LLAMA_TOKEN_TYPE_NORMAL = 1,
+
+    /// <summary>
+    /// An "unknown" character/text token e.g. &lt;unk&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNKNOWN = 2,
+
+    /// <summary>
+    /// A special control token e.g. &lt;/s&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_CONTROL = 3,
+
     LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
+
     LLAMA_TOKEN_TYPE_UNUSED = 5,
+
     LLAMA_TOKEN_TYPE_BYTE = 6,
 }
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 578cad40..902808f6 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -172,6 +172,11 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
 
+        /// <summary>
+        /// Get the batch size for this context
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
 
diff --git a/LLama/Native/RopeScalingType.cs b/LLama/Native/RopeScalingType.cs
index 435932e8..21199d10 100644
--- a/LLama/Native/RopeScalingType.cs
+++ b/LLama/Native/RopeScalingType.cs
@@ -1,17 +1,30 @@
 namespace LLama.Native
 {
     /// <summary>
-    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+    /// RoPE scaling type.
     /// </summary>
+    /// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
     public enum RopeScalingType
         : sbyte
     {
+        /// <summary>
+        /// No particular scaling type has been specified
+        /// </summary>
         LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
 
+        /// <summary>
+        /// Do not apply any RoPE scaling
+        /// </summary>
         LLAMA_ROPE_SCALING_NONE = 0,
 
+        /// <summary>
+        /// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
+        /// </summary>
         LLAMA_ROPE_SCALING_LINEAR = 1,
 
+        /// <summary>
+        /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
+        /// </summary>
         LLAMA_ROPE_SCALING_YARN = 2,
     }
 }
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs
index 4c1ea58d..9252e532 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/StreamingTokenDecoder.cs
@@ -142,22 +142,23 @@ namespace LLama
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
-        {
-            foreach (var item in tokens)
-                Add(item);
-        }
-
-        /// <summary>
-        /// Add all tokens in the given enumerable
-        /// </summary>
-        /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<LLamaToken> tokens)
+        public void AddRange<T>(T tokens)
+            where T : IEnumerable<LLamaToken>
         {
             foreach (var item in tokens)
                 Add((int)item);
         }
 
+        /// <summary>
+        /// Add all tokens in the given span
+        /// </summary>
+        /// <param name="tokens"></param>
+        public void AddRange(ReadOnlySpan<LLamaToken> tokens)
+        {
+            foreach (var item in tokens)
+                Add(item);
+        }
+
         /// <summary>
         /// Read all decoded characters and clear the buffer
         /// </summary>
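
The two `IEnumerable` overloads collapse into a single generic method, and the new `ReadOnlySpan` overload serves callers like `BeamTests` above, since a span cannot satisfy an `IEnumerable<T>` constraint. A minimal sketch of the two call shapes (the generic type arguments shown in the diff are reconstructed, so treat the exact signatures as assumptions):

```cs
using System;
using System.Collections.Generic;
using LLama;
using LLama.Native;

public static class DecoderExamples
{
    // An IEnumerable<LLamaToken> source binds to the generic AddRange<T>;
    // constraining T rather than accepting the interface directly lets
    // struct collections bind without being boxed.
    public static string Decode(LLamaContext context, List<LLamaToken> tokens)
    {
        var decoder = new StreamingTokenDecoder(context);
        decoder.AddRange(tokens);
        return decoder.Read();
    }

    // Spans cannot implement IEnumerable<T>, so slices go through the
    // dedicated ReadOnlySpan<LLamaToken> overload instead.
    public static string Decode(LLamaContext context, ReadOnlySpan<LLamaToken> tokens)
    {
        var decoder = new StreamingTokenDecoder(context);
        decoder.AddRange(tokens);
        return decoder.Read();
    }
}
```
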