Merge pull request #447 from martindevans/grow_nseqmax_batch

LLamaBatch Grow n_seq_max automatically
Martin Evans, 2024-01-21 01:07:12 +00:00 (committed by GitHub)
commit 0074320a31
2 changed files with 18 additions and 8 deletions
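
With this change, LLamaBatch no longer needs the maximum number of sequences per token up front: the parameterless constructor starts at n_seq_max = 1, and Add grows the batch whenever it receives more sequence ids than there is room for. A minimal usage sketch (not part of this diff; the explicit LLamaToken/LLamaPos/LLamaSeqId conversions from int are assumed, and the token ids are arbitrary):

using System;
using LLama.Native;

// Previously the caller had to guess an upper bound: new LLamaBatch(3).
// Now the batch starts small and grows itself on demand.
var batch = new LLamaBatch();

// One sequence id: fits the initial MaxSequences of 1.
batch.Add((LLamaToken)42, (LLamaPos)0, new[] { (LLamaSeqId)0 }, logits: false);

// Three sequence ids: sequences.Length (3) > MaxSequences (1), so Add
// calls GrowMaxSequences(3) internally before storing the token.
var seqs = new[] { (LLamaSeqId)0, (LLamaSeqId)1, (LLamaSeqId)2 };
batch.Add((LLamaToken)42, (LLamaPos)1, seqs, logits: true);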

File: BatchedDecoding (example)

@@ -52,7 +52,7 @@ public class BatchedDecoding
             return;
         }
 
-        var batch = new LLamaBatch(1);
+        var batch = new LLamaBatch();
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)

File: LLamaBatch (LLama.Native)

@@ -1,5 +1,4 @@
 using System;
-using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -35,11 +34,11 @@ public class LLamaBatch
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
     /// </summary>
-    /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
-    public LLamaBatch(int n_seq_max)
+    public LLamaBatch()
     {
-        // The number of tokens can be grown later, start off with a reasonable guess.
-        const int n_tokens = 64;
+        // These can both be grown later, start off with reasonable numbers.
+        const int n_tokens = 128;
+        const int n_seq_max = 1;
 
         MaxSequences = n_seq_max;
         TokenCapacity = n_tokens;
@@ -56,7 +55,7 @@ public class LLamaBatch
             _sequenceIds[i] = new LLamaSeqId[MaxSequences];
     }
 
-    private void Grow()
+    private void GrowTokenCapacity()
    {
         var n_tokens = TokenCount * 2;
         TokenCapacity = n_tokens;
@ -78,6 +77,15 @@ public class LLamaBatch
} }
} }
private void GrowMaxSequences(int atLeast)
{
var n_seq = Math.Max(MaxSequences * 2, atLeast);
MaxSequences = n_seq;
for (var i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds[i], MaxSequences);
}
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch) internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{ {
// This group holds all of the memory pins // This group holds all of the memory pins
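
A detail worth noting in the hunk above: Array.Resize copies the existing elements into the new array, so sequence ids already recorded for earlier tokens survive a mid-batch grow, and the new tail slots are default-initialized. A standalone illustration in plain C# (not from this diff):

using System;

var ids = new[] { 7 };        // one sequence id already stored for a token
Array.Resize(ref ids, 4);     // grow to the new MaxSequences
Console.WriteLine(string.Join(",", ids)); // prints "7,0,0,0": old id kept, tail zeroed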
@@ -120,7 +128,9 @@ public class LLamaBatch
     public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
         if (TokenCount == TokenCapacity)
-            Grow();
+            GrowTokenCapacity();
+        if (sequences.Length > MaxSequences)
+            GrowMaxSequences(sequences.Length);
 
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
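
Both growth paths double capacity rather than growing by one, which keeps repeated Add calls amortized O(1); GrowMaxSequences also takes atLeast as a floor, so a single token carrying many sequence ids is handled in one resize instead of several doublings. A small demonstration of that rule, using a hypothetical helper with the same arithmetic as the diff:

using System;

// Same rule as GrowMaxSequences: double, but never less than what is needed.
static int NextCapacity(int current, int atLeast) => Math.Max(current * 2, atLeast);

Console.WriteLine(NextCapacity(1, 2)); // 2: plain doubling
Console.WriteLine(NextCapacity(1, 5)); // 5: the floor jumps straight to the need
Console.WriteLine(NextCapacity(4, 5)); // 8: doubling dominates once capacity is close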