Merge pull request #721 from martindevans/kv_cache_view

Make `LLamaKvCacheView` Safe
This commit is contained in:
Martin Evans 2024-05-10 15:19:36 +01:00 committed by GitHub
commit 3b0b2ab224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 186 additions and 83 deletions

View File

@ -1,72 +1,58 @@
using System; using System;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace LLama.Native; namespace LLama.Native;
/// <summary>
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
/// </summary>
/// <remarks>
/// Declared with sequential layout and a single field.
/// NOTE(review): presumably this must stay field-for-field identical to the native
/// <c>llama_kv_cache_view_cell</c> struct - confirm against the llama.cpp header before changing it.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaKvCacheViewCell
{
/// <summary>
/// The position for this cell. Takes KV cache shifts into account.
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
}
/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary>
/// <remarks>
/// Declared with sequential layout, so the field order below is significant.
/// NOTE(review): the fields carry no access modifier and are therefore private, which means
/// code outside this struct cannot read any of the values or dereference the pointers.
/// NOTE(review): presumably the layout must match the native <c>llama_kv_cache_view</c>
/// struct exactly - confirm against the llama.cpp header before reordering or adding fields.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaKvCacheView
{
/// <summary>
/// Number of KV cache cells. This will be the same as the context size.
/// </summary>
int n_cells;
/// <summary>
/// Maximum number of sequences that can exist in a cell. It's not an error
/// if there are more sequences in a cell than this value, however they will
/// not be visible in the view cells_sequences.
/// </summary>
int n_seq_max;
/// <summary>
/// Number of tokens in the cache. For example, if there are two populated
/// cells, the first with 1 sequence id in it and the second with 2 sequence
/// ids then you'll have 3 tokens.
/// </summary>
int token_count;
/// <summary>
/// Number of populated cache cells.
/// </summary>
int used_cells;
/// <summary>
/// Maximum contiguous empty slots in the cache.
/// </summary>
int max_contiguous;
/// <summary>
/// Index to the start of the max_contiguous slot range. Can be negative
/// when cache is full.
/// </summary>
int max_contiguous_idx;
/// <summary>
/// Information for an individual cell.
/// </summary>
LLamaKvCacheViewCell* cells;
/// <summary>
/// The sequences for each cell. There will be n_seq_max items per cell.
/// </summary>
LLamaSeqId* cells_sequences;
}
/// <summary> /// <summary>
/// A safe handle for a LLamaKvCacheView /// A safe handle for a LLamaKvCacheView
/// </summary> /// </summary>
public class LLamaKvCacheViewSafeHandle public sealed class LLamaKvCacheViewSafeHandle
: SafeLLamaHandleBase : SafeLLamaHandleBase
{ {
private readonly SafeLLamaContextHandle _ctx; private readonly SafeLLamaContextHandle _ctx;
private LLamaKvCacheView _view; private NativeLLamaKvCacheView _view;
/// <summary>
/// Total number of cells in the KV cache (the native <c>n_cells</c> value).
/// This will be the same as the context size.
/// </summary>
public int CellCount
{
    get { return GetNativeView().n_cells; }
}

/// <summary>
/// Total number of tokens held in the KV cache (the native <c>token_count</c> value).
///
/// A cell contributes one token per sequence id it holds: two populated cells, the first
/// with 1 sequence id and the second with 2 sequence ids, give 3 tokens.
/// </summary>
public int TokenCount
{
    get { return GetNativeView().token_count; }
}

/// <summary>
/// The largest number of sequences this view can show for a single cell (the native
/// <c>n_seq_max</c> value). A cell may really hold more sequences than this; the extra
/// ones are simply not visible through this view.
/// </summary>
public int MaxSequenceCount
{
    get { return GetNativeView().n_seq_max; }
}

/// <summary>
/// Number of populated cache cells (the native <c>used_cells</c> value).
/// </summary>
public int UsedCellCount
{
    get { return GetNativeView().used_cells; }
}

/// <summary>
/// Maximum contiguous empty slots in the cache (the native <c>max_contiguous</c> value).
/// </summary>
public int MaxContiguous
{
    get { return GetNativeView().max_contiguous; }
}
/// <summary>
/// Index to the start of the <see cref="MaxContiguous"/> slot range (the native
/// <c>max_contiguous_idx</c> value). Can be negative when cache is full.
/// </summary>
public int MaxContiguousIdx => GetNativeView().max_contiguous_idx;
/// <summary> /// <summary>
/// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="view"></param> /// <param name="view"></param>
public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view) private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
: base((IntPtr)1, true) : base((IntPtr)1, true)
{ {
_ctx = ctx; _ctx = ctx;
@ -81,76 +67,176 @@ public class LLamaKvCacheViewSafeHandle
/// <returns></returns> /// <returns></returns>
public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences) public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
{ {
var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences); // Allocate the view
return new LLamaKvCacheViewSafeHandle(ctx, result); var view = llama_kv_cache_view_init(ctx, maxSequences);
var handle = new LLamaKvCacheViewSafeHandle(ctx, view);
// Update the view so it has valid data after allocation.
handle.Update();
return handle;
} }
/// <inheritdoc /> /// <inheritdoc />
protected override bool ReleaseHandle() protected override bool ReleaseHandle()
{ {
NativeApi.llama_kv_cache_view_free(ref _view); llama_kv_cache_view_free(ref _view);
SetHandle(IntPtr.Zero); SetHandle(IntPtr.Zero);
return true; return true;
} }
/// <summary> /// <summary>
/// Update this view /// Read the current KV cache state into this view.
/// </summary> /// </summary>
public void Update() public void Update()
{ {
NativeApi.llama_kv_cache_view_update(_ctx, ref _view); llama_kv_cache_view_update(_ctx, ref _view);
} }
/// <summary> /// <summary>
/// Get the raw KV cache view /// Get the raw KV cache view
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public ref LLamaKvCacheView GetView() private ref NativeLLamaKvCacheView GetNativeView()
{ {
// Refuse to hand out the native view once the handle is closed.
// The single-string ObjectDisposedException constructor interprets its argument as the
// object *name*, so the (objectName, message) overload is used to keep the explanation
// as the actual exception message.
if (IsClosed)
    throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle), "Cannot access LLamaKvCacheViewSafeHandle after it has been disposed");
return ref _view; return ref _view;
} }
}
public static partial class NativeApi /// <summary>
{ /// Get the cell at the given index
/// </summary>
/// <param name="index">The index of the cell [0, CellCount)</param>
/// <returns>Data about the cell at the given index</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
public LLamaPos GetCell(int index)
{
    // Fetch the view before validating the argument, so the disposed check
    // performed by GetNativeView runs ahead of the range checks.
    var native = GetNativeView();

    if (index < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
    }

    if (index >= native.n_cells)
    {
        throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");
    }

    unsafe
    {
        // Step to the requested cell in the native cell array and read its position.
        var cell = native.cells + index;
        return cell->pos;
    }
}
/// <summary>
/// Get the sequences assigned to the cell at the given index. The returned span always contains exactly
/// <see cref="MaxSequenceCount"/> entries, even if the cell actually has more sequences than that; allocate
/// a new view with a larger maxSequences parameter if more of them need to be visible.
/// Invalid sequences will be negative values.
/// </summary>
/// <remarks>
/// The span is built directly over the pointer held by the native view; no copy is made.
/// NOTE(review): presumably the span must not be used after the view is updated or disposed - confirm.
/// </remarks>
/// <param name="index">The index of the cell [0, CellCount)</param>
/// <returns>A span containing the sequences assigned to this cell</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
public Span<LLamaSeqId> GetCellSequences(int index)
{
    var view = GetNativeView();

    if (index < 0)
        throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
    if (index >= view.n_cells)
        throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");

    unsafe
    {
        // Widen before multiplying: `index * n_seq_max` in 32-bit arithmetic can silently wrap
        // for a large cache, which would produce a pointer outside the sequence buffer.
        var offset = (long)index * view.n_seq_max;

        // Each cell owns n_seq_max consecutive entries, starting at its offset.
        return new Span<LLamaSeqId>(view.cells_sequences + offset, view.n_seq_max);
    }
}
#region native API
/// <summary> /// <summary>
/// Create an empty KV cache view. (use only for debugging purposes) /// Create an empty KV cache view. (use only for debugging purposes)
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="n_seq_max"></param> /// <param name="n_seq_max"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);
/// <summary> /// <summary>
/// Free a KV cache view. (use only for debugging purposes) /// Free a KV cache view. (use only for debugging purposes)
/// </summary> /// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view); private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);
/// <summary> /// <summary>
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="view"></param> /// <param name="view"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view); private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
/// <summary> /// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug) /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
/// </summary> /// </summary>
/// <param name="ctx"></param> [StructLayout(LayoutKind.Sequential)]
/// <returns></returns> private struct NativeLLamaKvCacheViewCell
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] {
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); /// <summary>
/// The position for this cell. Takes KV cache shifts into account.
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
}
/// <summary> /// <summary>
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) /// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary> /// </summary>
/// <param name="ctx"></param> [StructLayout(LayoutKind.Sequential)]
/// <returns></returns> private unsafe struct NativeLLamaKvCacheView
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] {
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx); /// <summary>
/// Number of KV cache cells. This will be the same as the context size.
/// </summary>
public int n_cells;
/// <summary>
/// Maximum number of sequences that can exist in a cell. It's not an error
/// if there are more sequences in a cell than this value, however they will
/// not be visible in the view cells_sequences.
/// </summary>
public int n_seq_max;
/// <summary>
/// Number of tokens in the cache. For example, if there are two populated
/// cells, the first with 1 sequence id in it and the second with 2 sequence
/// ids then you'll have 3 tokens.
/// </summary>
public int token_count;
/// <summary>
/// Number of populated cache cells.
/// </summary>
public int used_cells;
/// <summary>
/// Maximum contiguous empty slots in the cache.
/// </summary>
public int max_contiguous;
/// <summary>
/// Index to the start of the max_contiguous slot range. Can be negative
/// when cache is full.
/// </summary>
public int max_contiguous_idx;
/// <summary>
/// Information for an individual cell.
/// </summary>
public NativeLLamaKvCacheViewCell* cells;
/// <summary>
/// The sequences for each cell. There will be n_seq_max items per cell.
/// </summary>
public LLamaSeqId* cells_sequences;
}
#endregion
} }

View File

@ -264,6 +264,23 @@ namespace LLama.Native
NativeLogConfig.llama_log_set(logCallback); NativeLogConfig.llama_log_set(logCallback);
} }
/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug).
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times.
/// </summary>
/// <param name="ctx">Handle of the context to query</param>
/// <returns>The token count, with a cell counted once per sequence assigned to it</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
/// <summary>
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
/// </summary>
/// <param name="ctx">Handle of the context to query</param>
/// <returns>The number of cells that have at least one sequence assigned to them</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
/// <summary> /// <summary>
/// Clear the KV cache /// Clear the KV cache
/// </summary> /// </summary>