Merge pull request #447 from martindevans/grow_nseqmax_batch
LLamaBatch Grow n_seq_max automatically
This commit is contained in:
commit
0074320a31
|
@ -52,7 +52,7 @@ public class BatchedDecoding
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var batch = new LLamaBatch(1);
|
var batch = new LLamaBatch();
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
for (var i = 0; i < prompt_tokens.Length; i++)
|
for (var i = 0; i < prompt_tokens.Length; i++)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
|
||||||
|
|
||||||
namespace LLama.Native;
|
namespace LLama.Native;
|
||||||
|
|
||||||
|
@ -35,11 +34,11 @@ public class LLamaBatch
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Create a new batch for submitting inputs to llama.cpp
|
/// Create a new batch for submitting inputs to llama.cpp
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
|
public LLamaBatch()
|
||||||
public LLamaBatch(int n_seq_max)
|
|
||||||
{
|
{
|
||||||
// The number of tokens can be grown later, start off with a reasonable guess.
|
// These can both be grown later, start off with reasonable numbers.
|
||||||
const int n_tokens = 64;
|
const int n_tokens = 128;
|
||||||
|
const int n_seq_max = 1;
|
||||||
|
|
||||||
MaxSequences = n_seq_max;
|
MaxSequences = n_seq_max;
|
||||||
TokenCapacity = n_tokens;
|
TokenCapacity = n_tokens;
|
||||||
|
@ -56,7 +55,7 @@ public class LLamaBatch
|
||||||
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
|
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
|
||||||
}
|
}
|
||||||
|
|
||||||
private void Grow()
|
private void GrowTokenCapacity()
|
||||||
{
|
{
|
||||||
var n_tokens = TokenCount * 2;
|
var n_tokens = TokenCount * 2;
|
||||||
TokenCapacity = n_tokens;
|
TokenCapacity = n_tokens;
|
||||||
|
@ -78,6 +77,15 @@ public class LLamaBatch
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void GrowMaxSequences(int atLeast)
|
||||||
|
{
|
||||||
|
var n_seq = Math.Max(MaxSequences * 2, atLeast);
|
||||||
|
MaxSequences = n_seq;
|
||||||
|
|
||||||
|
for (var i = 0; i < _sequenceIds.Length; i++)
|
||||||
|
Array.Resize(ref _sequenceIds[i], MaxSequences);
|
||||||
|
}
|
||||||
|
|
||||||
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
|
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
|
||||||
{
|
{
|
||||||
// This group holds all of the memory pins
|
// This group holds all of the memory pins
|
||||||
|
@ -120,7 +128,9 @@ public class LLamaBatch
|
||||||
public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
|
public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
|
||||||
{
|
{
|
||||||
if (TokenCount == TokenCapacity)
|
if (TokenCount == TokenCapacity)
|
||||||
Grow();
|
GrowTokenCapacity();
|
||||||
|
if (sequences.Length > MaxSequences)
|
||||||
|
GrowMaxSequences(sequences.Length);
|
||||||
|
|
||||||
_tokens[TokenCount] = token;
|
_tokens[TokenCount] = token;
|
||||||
_positions[TokenCount] = pos;
|
_positions[TokenCount] = pos;
|
||||||
|
|
Loading…
Reference in New Issue