using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
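/// <example>
/// A minimal usage sketch (illustrative only; <c>Conversation.Prompt</c> and <c>Conversation.Sample</c>
/// are assumed from the wider library and their exact signatures may differ between versions):
/// <code>
/// using var executor = new BatchedExecutor(model, contextParams);
/// using var conversation = executor.Create();
/// conversation.Prompt(executor.Context.Tokenize("Hello"));
/// await executor.Infer();
/// var token = conversation.Sample();
/// </code>
/// </example>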
public sealed class BatchedExecutor
    : IDisposable
{
    private int _nextSequenceId;

    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time <see cref="Infer"/> is called. Conversations can use this to
    /// keep track of whether they are waiting for inference or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }
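    // Illustrative note (not part of the original source): a conversation can record the value of
    // Epoch when it queues tokens; once the executor's Epoch has advanced past that recorded value,
    // the queued tokens have been decoded and the conversation can be sampled.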

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// The number of tokens currently queued in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;
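    // Illustrative note (not part of the original source): callers can check BatchedTokenCount > 0
    // before calling Infer, since there is nothing to decode while the batch is empty.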

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }
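    // A minimal construction sketch (illustrative only, assuming LLama.Common.ModelParams, which
    // implements IContextParams, so one parameters object can load the weights and create the context):
    //
    //     var parameters = new ModelParams("path/to/model.gguf");
    //     using var model = LLamaWeights.LoadFromFile(parameters);
    //     using var executor = new BatchedExecutor(model, parameters);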

    /// <summary>
    /// Start a new <see cref="Conversation"/>
    /// </summary>
    /// <returns>A new, empty <see cref="Conversation"/></returns>
    public Conversation Create()
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath">Path of the file to load the conversation from</param>
    /// <returns>The loaded conversation</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this executor has been disposed</exception>
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved into memory. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="state">State of a previously saved conversation</param>
    /// <returns>The loaded conversation</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this executor has been disposed</exception>
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }
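    // A minimal save/load sketch (illustrative only): the matching save methods on Conversation
    // (assumed here as Save(string) and Save() returning Conversation.State) live in the wider
    // library and may differ between versions.
    //
    //     conversation.Save("conversation.bin");
    //     using var restored = executor.Load("conversation.bin");
    //     // `restored` must be prompted again before the next call to Infer.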

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough space in the KV cache to process the batch;
    /// try disposing some conversations and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was ok. Leaving all this state in place means that "Infer" can
        // be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }
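    // A minimal sketch (illustrative only) of handling a NoKvSlot result, assuming the caller keeps
    // its own list of live conversations:
    //
    //     var result = await executor.Infer();
    //     if (result == DecodeResult.NoKvSlot)
    //     {
    //         // Not enough KV cache space: dispose a conversation and retry. The batch is left
    //         // intact after a failed decode, so no tokens need to be re-queued.
    //         conversations[0].Dispose();
    //         conversations.RemoveAt(0);
    //         result = await executor.Infer();
    //     }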

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        // Checked increment: throw on overflow rather than silently wrapping the sequence id counter.
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}