Assorted small changes to clean up some code warnings
This commit is contained in:
parent
9bc129e252
commit
c7d0dc915a
|
@ -47,14 +47,24 @@ public sealed class BeamTests
|
||||||
for (var i = 0; i < state.Beams.Length; i++)
|
for (var i = 0; i < state.Beams.Length; i++)
|
||||||
{
|
{
|
||||||
ref var view = ref state.Beams[i];
|
ref var view = ref state.Beams[i];
|
||||||
var tokens = context.DeTokenize(view.Tokens.ToArray());
|
|
||||||
|
var decoder = new StreamingTokenDecoder(context);
|
||||||
|
decoder.AddRange(view.Tokens);
|
||||||
|
var tokens = decoder.Read();
|
||||||
|
|
||||||
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
|
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.CommonPrefixLength > 0)
|
if (state.CommonPrefixLength > 0)
|
||||||
{
|
{
|
||||||
var view = state.Beams[0];
|
var view = state.Beams[0];
|
||||||
result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
|
|
||||||
|
var decoder = new StreamingTokenDecoder(context);
|
||||||
|
decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
|
||||||
|
var tokens = decoder.Read();
|
||||||
|
|
||||||
|
result.Append(tokens);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
|
}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
|
||||||
|
|
|
@ -36,12 +36,12 @@ namespace LLama.Abstractions
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public int TopK { get; set; }
|
public int TopK { get; set; }
|
||||||
|
|
||||||
/// <summary>llama_eval
|
/// <summary>
|
||||||
/// 1.0 = disabled
|
/// 1.0 = disabled
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public float TopP { get; set; }
|
public float TopP { get; set; }
|
||||||
|
|
||||||
/// <summary>llama_eval
|
/// <summary>
|
||||||
/// 0.0 = disabled
|
/// 0.0 = disabled
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public float MinP { get; set; }
|
public float MinP { get; set; }
|
||||||
|
|
|
@ -55,6 +55,9 @@ public sealed class BatchedExecutor
|
||||||
Epoch = 1;
|
Epoch = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Finalizer for BatchedExecutor
|
||||||
|
/// </summary>
|
||||||
~BatchedExecutor()
|
~BatchedExecutor()
|
||||||
{
|
{
|
||||||
Dispose();
|
Dispose();
|
||||||
|
|
|
@ -54,6 +54,9 @@ public sealed class Conversation
|
||||||
_end = end;
|
_end = end;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Finalizer for Conversation
|
||||||
|
/// </summary>
|
||||||
~Conversation()
|
~Conversation()
|
||||||
{
|
{
|
||||||
Dispose();
|
Dispose();
|
||||||
|
@ -96,7 +99,7 @@ public sealed class Conversation
|
||||||
AssertNotDisposed();
|
AssertNotDisposed();
|
||||||
|
|
||||||
if (RequiresInference)
|
if (RequiresInference)
|
||||||
throw new CannotForkWhileRequiresInference();
|
throw new CannotForkWhileRequiresInferenceException();
|
||||||
|
|
||||||
// Create a new conversation which references the current position in this one
|
// Create a new conversation which references the current position in this one
|
||||||
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
|
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
|
||||||
|
@ -195,13 +198,13 @@ public sealed class Conversation
|
||||||
/// Directly modify the KV cache of this conversation
|
/// Directly modify the KV cache of this conversation
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="modifier"></param>
|
/// <param name="modifier"></param>
|
||||||
/// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
|
/// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
|
||||||
public void Modify(ModifyKvCache modifier)
|
public void Modify(ModifyKvCache modifier)
|
||||||
{
|
{
|
||||||
AssertNotDisposed();
|
AssertNotDisposed();
|
||||||
|
|
||||||
if (RequiresInference)
|
if (RequiresInference)
|
||||||
throw new CannotModifyWhileRequiresInference();
|
throw new CannotModifyWhileRequiresInferenceException();
|
||||||
|
|
||||||
// do whatever the modification is
|
// do whatever the modification is
|
||||||
_end = modifier.Invoke(_end, new KvAccessor(this));
|
_end = modifier.Invoke(_end, new KvAccessor(this));
|
||||||
|
|
|
@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
|
/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public class CannotForkWhileRequiresInference
|
public class CannotForkWhileRequiresInferenceException
|
||||||
: ExperimentalBatchedExecutorException
|
: ExperimentalBatchedExecutorException
|
||||||
{
|
{
|
||||||
internal CannotForkWhileRequiresInference()
|
internal CannotForkWhileRequiresInferenceException()
|
||||||
: base("Cannot `Fork()` a conversation while RequiresInference is true")
|
: base("Cannot `Fork()` a conversation while RequiresInference is true")
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
|
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public class CannotModifyWhileRequiresInference
|
public class CannotModifyWhileRequiresInferenceException
|
||||||
: ExperimentalBatchedExecutorException
|
: ExperimentalBatchedExecutorException
|
||||||
{
|
{
|
||||||
internal CannotModifyWhileRequiresInference()
|
internal CannotModifyWhileRequiresInferenceException()
|
||||||
: base("Cannot `Modify()` a conversation while RequiresInference is true")
|
: base("Cannot `Modify()` a conversation while RequiresInference is true")
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
|
@ -155,7 +155,7 @@ namespace LLama
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||||
{
|
{
|
||||||
if (_embeds.Count > 0)
|
if (_embeds.Count > 0)
|
||||||
{
|
{
|
||||||
|
@ -238,6 +238,8 @@ namespace LLama
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
|
@ -1,12 +1,34 @@
|
||||||
namespace LLama.Native;
|
namespace LLama.Native;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Token Types
|
||||||
|
/// </summary>
|
||||||
|
/// <remarks>C# equivalent of llama_token_type</remarks>
|
||||||
public enum LLamaTokenType
|
public enum LLamaTokenType
|
||||||
{
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// No specific type has been set for this token
|
||||||
|
/// </summary>
|
||||||
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// This is a "normal" token
|
||||||
|
/// </summary>
|
||||||
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// An "unknown" character/text token e.g. &lt;unk&gt;
|
||||||
|
/// </summary>
|
||||||
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
|
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// A special control token e.g. &lt;/s&gt;
|
||||||
|
/// </summary>
|
||||||
LLAMA_TOKEN_TYPE_CONTROL = 3,
|
LLAMA_TOKEN_TYPE_CONTROL = 3,
|
||||||
|
|
||||||
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
|
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
|
||||||
|
|
||||||
LLAMA_TOKEN_TYPE_UNUSED = 5,
|
LLAMA_TOKEN_TYPE_UNUSED = 5,
|
||||||
|
|
||||||
LLAMA_TOKEN_TYPE_BYTE = 6,
|
LLAMA_TOKEN_TYPE_BYTE = 6,
|
||||||
}
|
}
|
|
@ -172,6 +172,11 @@ namespace LLama.Native
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
|
public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Get the batch size for this context
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="ctx"></param>
|
||||||
|
/// <returns></returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
|
public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,30 @@
|
||||||
namespace LLama.Native
|
namespace LLama.Native
|
||||||
{
|
{
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// RoPE scaling type. C# equivalent of llama_rope_scaling_type
|
/// RoPE scaling type.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
/// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
|
||||||
public enum RopeScalingType
|
public enum RopeScalingType
|
||||||
: sbyte
|
: sbyte
|
||||||
{
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// No particular scaling type has been specified
|
||||||
|
/// </summary>
|
||||||
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
|
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Do not apply any RoPE scaling
|
||||||
|
/// </summary>
|
||||||
LLAMA_ROPE_SCALING_NONE = 0,
|
LLAMA_ROPE_SCALING_NONE = 0,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
|
||||||
|
/// </summary>
|
||||||
LLAMA_ROPE_SCALING_LINEAR = 1,
|
LLAMA_ROPE_SCALING_LINEAR = 1,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
|
||||||
|
/// </summary>
|
||||||
LLAMA_ROPE_SCALING_YARN = 2,
|
LLAMA_ROPE_SCALING_YARN = 2,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -142,22 +142,23 @@ namespace LLama
|
||||||
/// Add all tokens in the given enumerable
|
/// Add all tokens in the given enumerable
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="tokens"></param>
|
/// <param name="tokens"></param>
|
||||||
public void AddRange(IEnumerable<int> tokens)
|
public void AddRange<T>(T tokens)
|
||||||
{
|
where T : IEnumerable<LLamaToken>
|
||||||
foreach (var item in tokens)
|
|
||||||
Add(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Add all tokens in the given enumerable
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="tokens"></param>
|
|
||||||
public void AddRange(IEnumerable<LLamaToken> tokens)
|
|
||||||
{
|
{
|
||||||
foreach (var item in tokens)
|
foreach (var item in tokens)
|
||||||
Add((int)item);
|
Add((int)item);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Add all tokens in the given span
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="tokens"></param>
|
||||||
|
public void AddRange(ReadOnlySpan<LLamaToken> tokens)
|
||||||
|
{
|
||||||
|
foreach (var item in tokens)
|
||||||
|
Add(item);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Read all decoded characters and clear the buffer
|
/// Read all decoded characters and clear the buffer
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
Loading…
Reference in New Issue