diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs index 86169c60..36379bfd 100644 --- a/LLama/Native/LLamaKvCacheView.cs +++ b/LLama/Native/LLamaKvCacheView.cs @@ -1,72 +1,58 @@ -using System; +using System; using System.Runtime.InteropServices; namespace LLama.Native; -/// -/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell) -/// -[StructLayout(LayoutKind.Sequential)] -public struct LLamaKvCacheViewCell -{ - /// - /// The position for this cell. Takes KV cache shifts into account. - /// May be negative if the cell is not populated. - /// - public LLamaPos pos; -} - -/// -/// An updateable view of the KV cache (llama_kv_cache_view) -/// -[StructLayout(LayoutKind.Sequential)] -public unsafe struct LLamaKvCacheView -{ - // Number of KV cache cells. This will be the same as the context size. - int n_cells; - - // Maximum number of sequences that can exist in a cell. It's not an error - // if there are more sequences in a cell than this value, however they will - // not be visible in the view cells_sequences. - int n_seq_max; - - // Number of tokens in the cache. For example, if there are two populated - // cells, the first with 1 sequence id in it and the second with 2 sequence - // ids then you'll have 3 tokens. - int token_count; - - // Number of populated cache cells. - int used_cells; - - // Maximum contiguous empty slots in the cache. - int max_contiguous; - - // Index to the start of the max_contiguous slot range. Can be negative - // when cache is full. - int max_contiguous_idx; - - // Information for an individual cell. - LLamaKvCacheViewCell* cells; - - // The sequences for each cell. There will be n_seq_max items per cell. - LLamaSeqId* cells_sequences; -} - /// /// A safe handle for a LLamaKvCacheView /// -public class LLamaKvCacheViewSafeHandle +public sealed class LLamaKvCacheViewSafeHandle : SafeLLamaHandleBase { private readonly SafeLLamaContextHandle _ctx; - private LLamaKvCacheView _view; + private NativeLLamaKvCacheView _view; + + /// + /// Number of KV cache cells. This will be the same as the context size. + /// + public int CellCount => GetNativeView().n_cells; + + /// + /// Get the total number of tokens in the KV cache. + /// + /// For example, if there are two populated + /// cells, the first with 1 sequence id in it and the second with 2 sequence + /// ids then you'll have 3 tokens. + /// + public int TokenCount => GetNativeView().token_count; + + /// + /// Maximum number of sequences visible for a cell. There may be more sequences than this + /// in reality, this is simply the maximum number this view can see. + /// + public int MaxSequenceCount => GetNativeView().n_seq_max; + + /// + /// Number of populated cache cells + /// + public int UsedCellCount => GetNativeView().used_cells; + + /// + /// Maximum contiguous empty slots in the cache. + /// + public int MaxContiguous => GetNativeView().max_contiguous; + + /// + /// Index to the start of the MaxContiguous slot range. Can be negative when cache is full. + /// + public int MaxContiguousIdx => GetNativeView().max_contiguous; /// /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed /// /// /// - public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view) + private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view) : base((IntPtr)1, true) { _ctx = ctx; @@ -81,76 +67,176 @@ public class LLamaKvCacheViewSafeHandle /// public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences) { - var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences); - return new LLamaKvCacheViewSafeHandle(ctx, result); + // Allocate the view + var view = llama_kv_cache_view_init(ctx, maxSequences); + var handle = new LLamaKvCacheViewSafeHandle(ctx, view); + + // Update the view so it has valid data after allocation. + handle.Update(); + + return handle; } /// protected override bool ReleaseHandle() { - NativeApi.llama_kv_cache_view_free(ref _view); + llama_kv_cache_view_free(ref _view); SetHandle(IntPtr.Zero); return true; } /// - /// Update this view + /// Read the current KV cache state into this view. /// public void Update() { - NativeApi.llama_kv_cache_view_update(_ctx, ref _view); + llama_kv_cache_view_update(_ctx, ref _view); } /// /// Get the raw KV cache view /// /// - public ref LLamaKvCacheView GetView() + private ref NativeLLamaKvCacheView GetNativeView() { + if (IsClosed) + throw new ObjectDisposedException("Cannot access LLamaKvCacheViewSafeHandle after is has been disposed"); + return ref _view; } -} -public static partial class NativeApi -{ + /// + /// Get the cell at the given index + /// + /// The index of the cell [0, CellCount) + /// Data about the cell at the given index + /// Thrown if index is out of range (0 <= index < CellCount) + public LLamaPos GetCell(int index) + { + var view = GetNativeView(); + + if (index < 0) + throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0"); + if (index >= view.n_cells) + throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount"); + + unsafe + { + return view.cells[index].pos; + } + } + + /// + /// Get all of the sequences assigned to the cell at the given index. This will contain entries + /// sequences even if the cell actually has more than that many sequences, allocate a new view with a larger maxSequences parameter + /// if necessary. Invalid sequences will be negative values. + /// + /// The index of the cell [0, CellCount) + /// A span containing the sequences assigned to this cell + /// Thrown if index is out of range (0 <= index < CellCount) + public Span GetCellSequences(int index) + { + var view = GetNativeView(); + + if (index < 0) + throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0"); + if (index >= view.n_cells) + throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount"); + + unsafe + { + return new Span(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max); + } + } + + #region native API /// /// Create an empty KV cache view. (use only for debugging purposes) /// /// /// /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); - + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); + /// /// Free a KV cache view. (use only for debugging purposes) /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view); - + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view); + /// /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) /// /// /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view); - + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view); + /// - /// Returns the number of tokens in the KV cache (slow, use only for debug) - /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times + /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell) /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); - + [StructLayout(LayoutKind.Sequential)] + private struct NativeLLamaKvCacheViewCell + { + /// + /// The position for this cell. Takes KV cache shifts into account. + /// May be negative if the cell is not populated. + /// + public LLamaPos pos; + } + /// - /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + /// An updateable view of the KV cache (llama_kv_cache_view) /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx); + [StructLayout(LayoutKind.Sequential)] + private unsafe struct NativeLLamaKvCacheView + { + /// + /// Number of KV cache cells. This will be the same as the context size. + /// + public int n_cells; + + /// + /// Maximum number of sequences that can exist in a cell. It's not an error + /// if there are more sequences in a cell than this value, however they will + /// not be visible in the view cells_sequences. + /// + public int n_seq_max; + + /// + /// Number of tokens in the cache. For example, if there are two populated + /// cells, the first with 1 sequence id in it and the second with 2 sequence + /// ids then you'll have 3 tokens. + /// + public int token_count; + + /// + /// Number of populated cache cells. + /// + public int used_cells; + + /// + /// Maximum contiguous empty slots in the cache. + /// + public int max_contiguous; + + /// + /// Index to the start of the max_contiguous slot range. Can be negative + /// when cache is full. + /// + public int max_contiguous_idx; + + /// + /// Information for an individual cell. + /// + public NativeLLamaKvCacheViewCell* cells; + + /// + /// The sequences for each cell. There will be n_seq_max items per cell. + /// + public LLamaSeqId* cells_sequences; + } + #endregion } \ No newline at end of file diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 708cdacc..9301198e 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -263,6 +263,23 @@ namespace LLama.Native { NativeLogConfig.llama_log_set(logCallback); } + + /// + /// Returns the number of tokens in the KV cache (slow, use only for debug) + /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); + + /// + /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx); /// /// Clear the KV cache