Assorted cleanup leftover after the huge change in the last PR (comments, syntax style, etc)

This commit is contained in:
Martin Evans 2023-10-19 00:26:30 +01:00
parent d8434ea9d6
commit 9daf586ba8
22 changed files with 260 additions and 204 deletions

View File

@ -34,11 +34,6 @@ namespace LLama.Abstractions
/// </summary>
string ModelPath { get; set; }
/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// </summary>
uint? Threads { get; set; }
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>

View File

@ -12,7 +12,6 @@ namespace LLama.Common
public class FixedSizeQueue<T>
: IEnumerable<T>
{
private readonly int _maxSize;
private readonly List<T> _storage;
internal IReadOnlyList<T> Items => _storage;
@ -25,7 +24,7 @@ namespace LLama.Common
/// <summary>
/// Maximum number of items allowed in this queue
/// </summary>
public int Capacity => _maxSize;
public int Capacity { get; }
/// <summary>
/// Create a new queue
@ -33,7 +32,7 @@ namespace LLama.Common
/// <param name="size">the maximum number of items to store in this queue</param>
public FixedSizeQueue(int size)
{
_maxSize = size;
Capacity = size;
_storage = new();
}
@ -52,11 +51,11 @@ namespace LLama.Common
#endif
// Size of "data" is unknown, copy it all into a list
_maxSize = size;
Capacity = size;
_storage = new List<T>(data);
// Now check if that list is a valid size.
if (_storage.Count > _maxSize)
if (_storage.Count > Capacity)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}
@ -81,7 +80,7 @@ namespace LLama.Common
public void Enqueue(T item)
{
_storage.Add(item);
if(_storage.Count >= _maxSize)
if(_storage.Count >= Capacity)
{
_storage.RemoveAt(0);
}

View File

@ -40,11 +40,11 @@ namespace LLama.Common
/// <summary>
/// Use mlock to keep model in memory (use_mlock)
/// </summary>
public bool UseMemoryLock { get; set; } = false;
public bool UseMemoryLock { get; set; }
/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
public bool Perplexity { get; set; } = false;
public bool Perplexity { get; set; }
/// <summary>
/// Model path (model)
/// </summary>
@ -79,7 +79,7 @@ namespace LLama.Common
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
public bool EmbeddingMode { get; set; } = false;
public bool EmbeddingMode { get; set; }
/// <summary>
/// how split tensors should be distributed across GPUs

View File

@ -58,7 +58,7 @@ public class GrammarUnexpectedEndOfInput
: GrammarFormatException
{
internal GrammarUnexpectedEndOfInput()
: base($"Unexpected end of input")
: base("Unexpected end of input")
{
}
}

View File

@ -1,19 +1,20 @@
using System;
namespace LLama.Exceptions
namespace LLama.Exceptions;
/// <summary>
/// Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions)
/// </summary>
public class RuntimeError
: Exception
{
public class RuntimeError
: Exception
/// <summary>
/// Create a new RuntimeError
/// </summary>
/// <param name="message"></param>
public RuntimeError(string message)
: base(message)
{
public RuntimeError()
{
}
public RuntimeError(string message)
: base(message)
{
}
}
}
}

View File

@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Extensions;

View File

@ -1,26 +1,23 @@
using System.Collections.Generic;
namespace LLama.Extensions;
namespace LLama.Extensions
/// <summary>
/// Extensions to the KeyValuePair struct
/// </summary>
internal static class KeyValuePairExtensions
{
/// <summary>
/// Extensions to the KeyValuePair struct
/// </summary>
internal static class KeyValuePairExtensions
{
#if NETSTANDARD2_0
/// <summary>
/// Deconstruct a KeyValuePair into it's constituent parts.
/// </summary>
/// <param name="pair">The KeyValuePair to deconstruct</param>
/// <param name="first">First element, the Key</param>
/// <param name="second">Second element, the Value</param>
/// <typeparam name="TKey">Type of the Key</typeparam>
/// <typeparam name="TValue">Type of the Value</typeparam>
public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
{
first = pair.Key;
second = pair.Value;
}
#endif
/// <summary>
/// Deconstruct a KeyValuePair into it's constituent parts.
/// </summary>
/// <param name="pair">The KeyValuePair to deconstruct</param>
/// <param name="first">First element, the Key</param>
/// <param name="second">Second element, the Value</param>
/// <typeparam name="TKey">Type of the Key</typeparam>
/// <typeparam name="TValue">Type of the Value</typeparam>
public static void Deconstruct<TKey, TValue>(this System.Collections.Generic.KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
{
first = pair.Key;
second = pair.Value;
}
}
#endif
}

View File

@ -17,7 +17,7 @@ namespace LLama.Grammars
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
{
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@ -40,46 +40,12 @@ namespace LLama.Grammars
return value;
}
private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
{
uint nextId = (uint)state.SymbolIds.Count;
string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
if (state.SymbolIds.TryGetValue(key, out uint existingId))
{
return existingId;
}
else
{
state.SymbolIds[key] = nextId;
return nextId;
}
}
private uint GenerateSymbolId(ParseState state, string baseName)
{
uint nextId = (uint)state.SymbolIds.Count;
string key = $"{baseName}_{nextId}";
state.SymbolIds[key] = nextId;
return nextId;
}
private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
{
while (state.Rules.Count <= ruleId)
{
state.Rules.Add(new List<LLamaGrammarElement>());
}
state.Rules[(int)ruleId] = rule;
}
private bool IsWordChar(byte c)
private static bool IsWordChar(byte c)
{
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}
private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
{
int pos = 0;
int end = size;
@ -115,7 +81,7 @@ namespace LLama.Grammars
return value;
}
private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{
int pos = 0;
while (pos < src.Length &&
@ -137,7 +103,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}
private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{
int pos = 0;
while (pos < src.Length && IsWordChar(src[pos]))
@ -151,7 +117,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}
private uint ParseChar(ref ReadOnlySpan<byte> src)
private static uint ParseChar(ref ReadOnlySpan<byte> src)
{
if (src[0] == '\\')
{
@ -235,7 +201,7 @@ namespace LLama.Grammars
else if (IsWordChar(pos[0])) // rule reference
{
var nameEnd = ParseName(pos);
uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
pos = ParseSpace(nameEnd, isNested);
lastSymStart = outElements.Count;
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
@ -244,7 +210,7 @@ namespace LLama.Grammars
{
// parse nested alternates into synthesized rule
pos = ParseSpace(pos.Slice(1), true);
uint subRuleId = GenerateSymbolId(state, ruleName);
uint subRuleId = state.GenerateSymbolId(ruleName);
pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
lastSymStart = outElements.Count;
// output reference to synthesized rule
@ -263,7 +229,7 @@ namespace LLama.Grammars
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint subRuleId = GenerateSymbolId(state, ruleName);
uint subRuleId = state.GenerateSymbolId(ruleName);
List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>();
@ -287,7 +253,7 @@ namespace LLama.Grammars
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
AddRule(state, subRuleId, subRule);
state.AddRule(subRuleId, subRule);
// in original rule, replace previous symbol with reference to generated rule
outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
@ -323,7 +289,7 @@ namespace LLama.Grammars
}
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
AddRule(state, ruleId, rule);
state.AddRule(ruleId, rule);
return pos;
}
@ -333,7 +299,7 @@ namespace LLama.Grammars
ReadOnlySpan<byte> nameEnd = ParseName(src);
ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
int nameLen = src.Length - nameEnd.Length;
uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0);
uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
@ -393,6 +359,40 @@ namespace LLama.Grammars
{
public SortedDictionary<string, uint> SymbolIds { get; } = new();
public List<List<LLamaGrammarElement>> Rules { get; } = new();
public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
{
var nextId = (uint)SymbolIds.Count;
var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
if (SymbolIds.TryGetValue(key, out uint existingId))
{
return existingId;
}
else
{
SymbolIds[key] = nextId;
return nextId;
}
}
public uint GenerateSymbolId(string baseName)
{
var nextId = (uint)SymbolIds.Count;
var key = $"{baseName}_{nextId}";
SymbolIds[key] = nextId;
return nextId;
}
public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
{
while (Rules.Count <= ruleId)
{
Rules.Add(new List<LLamaGrammarElement>());
}
Rules[(int)ruleId] = rule;
}
}
}
}

View File

@ -112,7 +112,6 @@ namespace LLama.Grammars
case LLamaGrammarElementType.CHAR_ALT:
PrintGrammarChar(output, elem.Value);
break;
}
if (elem.IsCharElement())

View File

@ -23,23 +23,21 @@ namespace LLama
: IDisposable
{
private readonly ILogger? _logger;
private readonly Encoding _encoding;
private readonly SafeLLamaContextHandle _ctx;
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount => _ctx.VocabCount;
public int VocabCount => NativeHandle.VocabCount;
/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => _ctx.ContextSize;
public int ContextSize => NativeHandle.ContextSize;
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
public int EmbeddingSize => NativeHandle.EmbeddingSize;
/// <summary>
/// The context params set for this context
@ -50,20 +48,20 @@ namespace LLama
/// The native handle, which is used to be passed to the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
public SafeLLamaContextHandle NativeHandle => _ctx;
public SafeLLamaContextHandle NativeHandle { get; }
/// <summary>
/// The encoding set for this model to deal with text input.
/// </summary>
public Encoding Encoding => _encoding;
public Encoding Encoding { get; }
internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;
_logger = logger;
_encoding = @params.Encoding;
_ctx = nativeContext;
Encoding = @params.Encoding;
NativeHandle = nativeContext;
}
/// <summary>
@ -81,10 +79,10 @@ namespace LLama
Params = @params;
_logger = logger;
_encoding = @params.Encoding;
Encoding = @params.Encoding;
@params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
}
/// <summary>
@ -96,7 +94,7 @@ namespace LLama
/// <returns></returns>
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
{
return _ctx.Tokenize(text, addBos, special, _encoding);
return NativeHandle.Tokenize(text, addBos, special, Encoding);
}
/// <summary>
@ -108,7 +106,7 @@ namespace LLama
{
var sb = new StringBuilder();
foreach (var token in tokens)
_ctx.TokenToString(token, _encoding, sb);
NativeHandle.TokenToString(token, Encoding, sb);
return sb.ToString();
}
@ -124,7 +122,7 @@ namespace LLama
File.Delete(filename);
// Estimate size of state to write to disk, this is always equal to or greater than the actual size
var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx);
var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle);
// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
@ -135,7 +133,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr);
writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@ -151,14 +149,14 @@ namespace LLama
/// <returns></returns>
public State GetState()
{
var stateSize = _ctx.GetStateSize();
var stateSize = NativeHandle.GetStateSize();
// Allocate a chunk of memory large enough to hold the entire state
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy the state data into memory, discover the actual size required
var actualSize = _ctx.GetState(memory, stateSize);
var actualSize = NativeHandle.GetState(memory, stateSize);
// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@ -193,7 +191,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
NativeApi.llama_set_state_data(_ctx, ptr);
NativeApi.llama_set_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@ -208,7 +206,7 @@ namespace LLama
{
unsafe
{
_ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
}
}
@ -235,13 +233,13 @@ namespace LLama
if (grammar != null)
{
SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
}
else
{
@ -250,23 +248,23 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token(_ctx, candidates);
SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token(NativeHandle, candidates);
}
}
mirostat_mu = mu;
@ -274,7 +272,7 @@ namespace LLama
if (grammar != null)
{
NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}
return id;
@ -295,7 +293,7 @@ namespace LLama
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
var logits = _ctx.GetLogits();
var logits = NativeHandle.GetLogits();
// Apply params.logit_bias map
if (logitBias is not null)
@ -305,7 +303,7 @@ namespace LLama
}
// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(_ctx);
var nl_token = NativeApi.llama_token_nl(NativeHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
@ -316,8 +314,8 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);
SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
if (!penalizeNL)
@ -408,9 +406,9 @@ namespace LLama
n_eval = (int)Params.BatchSize;
}
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
_logger?.LogError($"[LLamaContext] Failed to eval.");
_logger?.LogError("[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");
}
@ -443,7 +441,7 @@ namespace LLama
/// <inheritdoc />
public void Dispose()
{
_ctx.Dispose();
NativeHandle.Dispose();
}
/// <summary>

View File

@ -18,11 +18,22 @@ namespace LLama
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
/// <summary>
/// Create a new embedder (loading temporary weights)
/// </summary>
/// <param name="allParams"></param>
[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
public LLamaEmbedder(ILLamaParams allParams)
: this(allParams, allParams)
{
}
/// <summary>
/// Create a new embedder (loading temporary weights)
/// </summary>
/// <param name="modelParams"></param>
/// <param name="contextParams"></param>
[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);
@ -31,6 +42,11 @@ namespace LLama
_ctx = weights.CreateContext(contextParams);
}
/// <summary>
/// Create a new embedder, using the given LLamaWeights
/// </summary>
/// <param name="weights"></param>
/// <param name="params"></param>
public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
{
@params.EmbeddingMode = true;

View File

@ -114,7 +114,7 @@ namespace LLama
}
else
{
_logger?.LogWarning($"[LLamaExecutor] Session file does not exist, will create");
_logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
}
_n_matching_session_tokens = 0;

View File

@ -18,10 +18,10 @@ namespace LLama
/// </summary>
public class InstructExecutor : StatefulExecutorBase
{
bool _is_prompt_run = true;
string _instructionPrefix;
llama_token[] _inp_pfx;
llama_token[] _inp_sfx;
private bool _is_prompt_run = true;
private readonly string _instructionPrefix;
private llama_token[] _inp_pfx;
private llama_token[] _inp_sfx;
/// <summary>
///

View File

@ -80,6 +80,7 @@ namespace LLama
return true;
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
case LLamaFtype.LLAMA_FTYPE_GUESSED:
default:
return false;
}

View File

@ -11,13 +11,11 @@ namespace LLama
public sealed class LLamaWeights
: IDisposable
{
private readonly SafeLlamaModelHandle _weights;
/// <summary>
/// The native handle, which is used in the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle => _weights;
public SafeLlamaModelHandle NativeHandle { get; }
/// <summary>
/// Total number of tokens in vocabulary of this model
@ -46,7 +44,7 @@ namespace LLama
internal LLamaWeights(SafeLlamaModelHandle weights)
{
_weights = weights;
NativeHandle = weights;
}
/// <summary>
@ -66,7 +64,7 @@ namespace LLama
if (adapter.Scale <= 0)
continue;
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
}
return new LLamaWeights(weights);
@ -75,7 +73,7 @@ namespace LLama
/// <inheritdoc />
public void Dispose()
{
_weights.Dispose();
NativeHandle.Dispose();
}
/// <summary>

View File

@ -4,11 +4,18 @@ namespace LLama.Native;
using llama_token = Int32;
/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
public LLamaNativeBatch Batch { get; private set; }
/// <summary>
/// Get the native llama_batch struct
/// </summary>
public LLamaNativeBatch NativeBatch { get; private set; }
/// <summary>
/// the token ids of the input (used when embd is NULL)
@ -22,7 +29,7 @@ public sealed class LLamaBatchSafeHandle
if (_embd != 0)
return new Span<int>(null, 0);
else
return new Span<int>(Batch.token, Batch.n_tokens);
return new Span<int>(NativeBatch.token, NativeBatch.n_tokens);
}
}
}
@ -37,10 +44,10 @@ public sealed class LLamaBatchSafeHandle
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
if (_embd != 0)
return new Span<llama_token>(Batch.embd, Batch.n_tokens * _embd);
return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
else
return new Span<llama_token>(null, 0);
}
@ -56,7 +63,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
return new Span<LLamaPos>(Batch.pos, Batch.n_tokens);
return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
}
}
}
@ -70,7 +77,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens);
return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
}
}
}
@ -84,22 +91,40 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
return new Span<byte>(Batch.logits, Batch.n_tokens);
return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
}
}
}
public LLamaBatchSafeHandle(int n_tokens, int embd)
/// <summary>
/// Create a safe handle owning a `LLamaNativeBatch`
/// </summary>
/// <param name="batch"></param>
/// <param name="embd"></param>
public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
: base((nint)1)
{
_embd = embd;
Batch = NativeApi.llama_batch_init(n_tokens, embd);
NativeBatch = batch;
}
/// <summary>
/// Call `llama_batch_init` and create a new batch
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
/// <returns></returns>
public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
{
var batch = NativeApi.llama_batch_init(n_tokens, embd);
return new LLamaBatchSafeHandle(batch, embd);
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_batch_free(Batch);
Batch = default;
NativeApi.llama_batch_free(NativeBatch);
NativeBatch = default;
SetHandle(IntPtr.Zero);
return true;
}

View File

@ -45,7 +45,7 @@ namespace LLama.Native
/// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
/// </summary>
CHAR_ALT = 6,
};
}
/// <summary>
/// An element of a grammar

View File

@ -1,15 +1,26 @@
namespace LLama.Native;
public record struct LLamaPos
/// <summary>
/// Indicates position in a sequence
/// </summary>
public readonly record struct LLamaPos(int Value)
{
public int Value;
public LLamaPos(int value)
{
Value = value;
}
/// <summary>
/// The raw value
/// </summary>
public readonly int Value = Value;
/// <summary>
/// Convert a LLamaPos into an integer (extract the raw value)
/// </summary>
/// <param name="pos"></param>
/// <returns></returns>
public static explicit operator int(LLamaPos pos) => pos.Value;
/// <summary>
/// Convert an integer into a LLamaPos
/// </summary>
/// <param name="value"></param>
/// <returns></returns>
public static implicit operator LLamaPos(int value) => new(value);
}

View File

@ -1,15 +1,26 @@
namespace LLama.Native;
public record struct LLamaSeqId
/// <summary>
/// ID for a sequence in a batch
/// </summary>
/// <param name="Value"></param>
public record struct LLamaSeqId(int Value)
{
public int Value;
public LLamaSeqId(int value)
{
Value = value;
}
/// <summary>
/// The raw value
/// </summary>
public int Value = Value;
/// <summary>
/// Convert a LLamaSeqId into an integer (extract the raw value)
/// </summary>
/// <param name="pos"></param>
public static explicit operator int(LLamaSeqId pos) => pos.Value;
/// <summary>
/// Convert an integer into a LLamaSeqId
/// </summary>
/// <param name="value"></param>
/// <returns></returns>
public static explicit operator LLamaSeqId(int value) => new(value);
}

View File

@ -1,28 +1,28 @@
using System.Runtime.InteropServices;
namespace LLama.Native
{
[StructLayout(LayoutKind.Sequential)]
public struct LLamaTokenData
{
/// <summary>
/// token id
/// </summary>
public int id;
/// <summary>
/// log-odds of the token
/// </summary>
public float logit;
/// <summary>
/// probability of the token
/// </summary>
public float p;
namespace LLama.Native;
public LLamaTokenData(int id, float logit, float p)
{
this.id = id;
this.logit = logit;
this.p = p;
}
}
}
/// <summary>
/// A single token along with probability of this token being selected
/// </summary>
/// <param name="id"></param>
/// <param name="logit"></param>
/// <param name="p"></param>
[StructLayout(LayoutKind.Sequential)]
public record struct LLamaTokenData(int id, float logit, float p)
{
/// <summary>
/// token id
/// </summary>
public int id = id;
/// <summary>
/// log-odds of the token
/// </summary>
public float logit = logit;
/// <summary>
/// probability of the token
/// </summary>
public float p = p;
}

View File

@ -1,7 +1,5 @@
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@ -212,9 +210,17 @@ namespace LLama.Native
}
}
/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatchSafeHandle batch)
{
return NativeApi.llama_decode(this, batch.Batch);
return NativeApi.llama_decode(this, batch.NativeBatch);
}
#region state

View File

@ -84,14 +84,14 @@ namespace LLama.Native
/// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
scale,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
(int?)threads ?? -1
threads ?? Math.Max(1, Environment.ProcessorCount / 2)
);
if (err != 0)