Assorted small changes to clean up some code warnings

Martin Evans 2024-02-17 23:07:10 +00:00
parent 9bc129e252
commit c7d0dc915a
10 changed files with 83 additions and 24 deletions

View File

@@ -47,14 +47,24 @@ public sealed class BeamTests
         for (var i = 0; i < state.Beams.Length; i++)
         {
             ref var view = ref state.Beams[i];
-            var tokens = context.DeTokenize(view.Tokens.ToArray());
+
+            var decoder = new StreamingTokenDecoder(context);
+            decoder.AddRange(view.Tokens);
+            var tokens = decoder.Read();
+
             _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
         }
 
         if (state.CommonPrefixLength > 0)
         {
             var view = state.Beams[0];
-            result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
+
+            var decoder = new StreamingTokenDecoder(context);
+            decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
+            var tokens = decoder.Read();
+            result.Append(tokens);
         }
     }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
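
For context, the test now routes tokens through StreamingTokenDecoder instead of calling DeTokenize directly. A minimal sketch of the decoder pattern, assuming an existing LLamaContext `context` and a `tokens` collection (both placeholders, not part of this diff):

    // Buffer tokens, then read the decoded text in one go. The streaming
    // decoder correctly handles characters that span multiple tokens, which
    // a naive per-token detokenize can split into invalid fragments.
    var decoder = new StreamingTokenDecoder(context);
    decoder.AddRange(tokens);
    var text = decoder.Read();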

View File

@@ -36,12 +36,12 @@ namespace LLama.Abstractions
         /// </summary>
         public int TopK { get; set; }
 
-        /// <summary>llama_eval
+        /// <summary>
         /// 1.0 = disabled
         /// </summary>
         public float TopP { get; set; }
 
-        /// <summary>llama_eval
+        /// <summary>
         /// 0.0 = disabled
         /// </summary>
         public float MinP { get; set; }
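
The corrected summaries describe the standard "disabled" values for these samplers. A sketch of setting them, assuming the concrete InferenceParams class that implements this interface:

    var inferenceParams = new InferenceParams
    {
        TopK = 40,    // sample from the 40 most likely tokens
        TopP = 1.0f,  // nucleus sampling; 1.0 = disabled
        MinP = 0.0f,  // min-p filtering; 0.0 = disabled
    };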

View File

@@ -55,6 +55,9 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }
 
+    /// <summary>
+    /// Finalizer for BatchedExecutor
+    /// </summary>
     ~BatchedExecutor()
     {
         Dispose();
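
The newly documented finalizer is the usual safety net for a missed Dispose() call. A minimal sketch of that pattern with a hypothetical class, not code from this commit:

    public sealed class NativeResourceHolder : IDisposable
    {
        private bool _disposed;

        /// <summary>
        /// Finalizer: runs only if the caller forgot to Dispose()
        /// </summary>
        ~NativeResourceHolder()
        {
            Dispose();
        }

        public void Dispose()
        {
            if (_disposed)
                return;
            _disposed = true;

            // free native resources here...

            // cleanup has run, the finalizer is no longer needed
            GC.SuppressFinalize(this);
        }
    }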

View File

@@ -54,6 +54,9 @@ public sealed class Conversation
         _end = end;
     }
 
+    /// <summary>
+    /// Finalizer for Conversation
+    /// </summary>
     ~Conversation()
     {
         Dispose();
@@ -96,7 +99,7 @@ public sealed class Conversation
         AssertNotDisposed();
 
         if (RequiresInference)
-            throw new CannotForkWhileRequiresInference();
+            throw new CannotForkWhileRequiresInferenceException();
 
         // Create a new conversation which references the current position in this one
         var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
@@ -195,13 +198,13 @@ public sealed class Conversation
     /// Directly modify the KV cache of this conversation
     /// </summary>
     /// <param name="modifier"></param>
-    /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
+    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
     public void Modify(ModifyKvCache modifier)
    {
         AssertNotDisposed();
 
         if (RequiresInference)
-            throw new CannotModifyWhileRequiresInference();
+            throw new CannotModifyWhileRequiresInferenceException();
 
         // do whatever the modification is
         _end = modifier.Invoke(_end, new KvAccessor(this));

View File

@@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInference
+public class CannotForkWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotForkWhileRequiresInference()
+    internal CannotForkWhileRequiresInferenceException()
         : base("Cannot `Fork()` a conversation while RequiresInference is true")
     {
     }
@@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotModifyWhileRequiresInference
+public class CannotModifyWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotModifyWhileRequiresInference()
+    internal CannotModifyWhileRequiresInferenceException()
         : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
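
With the rename, call sites catch the `Exception`-suffixed types. A hypothetical call site, where `conversation` and `executor` are placeholder variables:

    try
    {
        var forked = conversation.Fork();
    }
    catch (CannotForkWhileRequiresInferenceException)
    {
        // Inference is pending; run it first, then fork.
        await executor.Infer();
        var forked = conversation.Fork();
    }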

View File

@@ -155,7 +155,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -238,6 +238,8 @@ namespace LLama
                     }
                 }
             }
+
+            return Task.CompletedTask;
         }
 
         /// <summary>
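
Dropping `async` from a method that never awaits silences compiler warning CS1998 and avoids generating a state machine; the method hands back an already-completed task instead. The general shape, sketched with a hypothetical method:

    // Before: `async Task` with no awaits triggers CS1998.
    // After: do the work synchronously and return a completed task.
    protected Task DoWorkAsync()
    {
        // ... synchronous work ...
        return Task.CompletedTask;
    }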

View File

@@ -1,12 +1,34 @@
 namespace LLama.Native;
 
+/// <summary>
+/// Token Types
+/// </summary>
+/// <remarks>C# equivalent of llama_token_get_type</remarks>
 public enum LLamaTokenType
 {
+    /// <summary>
+    /// No specific type has been set for this token
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNDEFINED = 0,
+
+    /// <summary>
+    /// This is a "normal" token
+    /// </summary>
     LLAMA_TOKEN_TYPE_NORMAL = 1,
+
+    /// <summary>
+    /// An "unknown" character/text token e.g. &lt;unk&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNKNOWN = 2,
+
+    /// <summary>
+    /// A special control token e.g. &lt;/s&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_CONTROL = 3,
+
     LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
     LLAMA_TOKEN_TYPE_UNUSED = 5,
     LLAMA_TOKEN_TYPE_BYTE = 6,
 }
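
A hypothetical helper showing how the newly documented values might be consumed, for illustration only:

    // Treat control/unknown tokens as "special" when rendering output.
    static bool IsSpecial(LLamaTokenType type) => type switch
    {
        LLamaTokenType.LLAMA_TOKEN_TYPE_CONTROL => true, // e.g. </s>
        LLamaTokenType.LLAMA_TOKEN_TYPE_UNKNOWN => true, // e.g. <unk>
        _ => false,
    };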

View File

@@ -172,6 +172,11 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
 
+        /// <summary>
+        /// Get the batch size for this context
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
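
A sketch of calling the newly documented export alongside its sibling, assuming an already-initialised SafeLLamaContextHandle `ctx` (a placeholder):

    // n_ctx is the total context (KV cache) size in tokens;
    // n_batch is the maximum number of tokens submitted per decode call.
    uint nCtx = NativeApi.llama_n_ctx(ctx);
    uint nBatch = NativeApi.llama_n_batch(ctx);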

View File

@@ -1,17 +1,30 @@
 namespace LLama.Native
 {
     /// <summary>
-    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+    /// RoPE scaling type.
     /// </summary>
+    /// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
     public enum RopeScalingType
         : sbyte
     {
+        /// <summary>
+        /// No particular scaling type has been specified
+        /// </summary>
         LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+
+        /// <summary>
+        /// Do not apply any RoPE scaling
+        /// </summary>
         LLAMA_ROPE_SCALING_NONE = 0,
+
+        /// <summary>
+        /// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
+        /// </summary>
         LLAMA_ROPE_SCALING_LINEAR = 1,
+
+        /// <summary>
+        /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
+        /// </summary>
         LLAMA_ROPE_SCALING_YARN = 2,
     }
 }
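
A sketch of selecting one of these values when configuring a model. The `YarnScalingType` property name and the model path are assumptions about the surrounding parameter API, not something shown in this diff:

    var modelParams = new ModelParams("model.gguf")
    {
        // Opt in to YaRN context extension (assumed property name)
        YarnScalingType = RopeScalingType.LLAMA_ROPE_SCALING_YARN,
    };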

View File

@@ -142,22 +142,23 @@ namespace LLama
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
-        {
-            foreach (var item in tokens)
-                Add(item);
-        }
-
-        /// <summary>
-        /// Add all tokens in the given enumerable
-        /// </summary>
-        /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<LLamaToken> tokens)
+        public void AddRange<T>(T tokens)
+            where T : IEnumerable<LLamaToken>
         {
             foreach (var item in tokens)
                 Add((int)item);
         }
 
+        /// <summary>
+        /// Add all tokens in the given span
+        /// </summary>
+        /// <param name="tokens"></param>
+        public void AddRange(ReadOnlySpan<LLamaToken> tokens)
+        {
+            foreach (var item in tokens)
+                Add(item);
+        }
+
         /// <summary>
         /// Read all decoded characters and clear the buffer
         /// </summary>
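
The generic `AddRange<T>` collapses the two old overloads into one implementation while letting struct enumerables avoid boxing, and the new span overload covers tokens already in contiguous memory. A usage sketch, with decoder construction as above and placeholder token values:

    var decoder = new StreamingTokenDecoder(context);

    // Any IEnumerable<LLamaToken> goes through the generic overload.
    decoder.AddRange(new List<LLamaToken> { tokenA, tokenB });

    // Contiguous tokens take the span overload, with no enumerator at all.
    ReadOnlySpan<LLamaToken> span = stackalloc LLamaToken[] { tokenA, tokenB };
    decoder.AddRange(span);

    var text = decoder.Read();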