diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index 86169c60..36379bfd 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -1,72 +1,58 @@
-using System;
+using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
-///
-/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
-///
-[StructLayout(LayoutKind.Sequential)]
-public struct LLamaKvCacheViewCell
-{
- ///
- /// The position for this cell. Takes KV cache shifts into account.
- /// May be negative if the cell is not populated.
- ///
- public LLamaPos pos;
-}
-
-///
-/// An updateable view of the KV cache (llama_kv_cache_view)
-///
-[StructLayout(LayoutKind.Sequential)]
-public unsafe struct LLamaKvCacheView
-{
- // Number of KV cache cells. This will be the same as the context size.
- int n_cells;
-
- // Maximum number of sequences that can exist in a cell. It's not an error
- // if there are more sequences in a cell than this value, however they will
- // not be visible in the view cells_sequences.
- int n_seq_max;
-
- // Number of tokens in the cache. For example, if there are two populated
- // cells, the first with 1 sequence id in it and the second with 2 sequence
- // ids then you'll have 3 tokens.
- int token_count;
-
- // Number of populated cache cells.
- int used_cells;
-
- // Maximum contiguous empty slots in the cache.
- int max_contiguous;
-
- // Index to the start of the max_contiguous slot range. Can be negative
- // when cache is full.
- int max_contiguous_idx;
-
- // Information for an individual cell.
- LLamaKvCacheViewCell* cells;
-
- // The sequences for each cell. There will be n_seq_max items per cell.
- LLamaSeqId* cells_sequences;
-}
-
///
/// A safe handle for a LLamaKvCacheView
///
-public class LLamaKvCacheViewSafeHandle
+public sealed class LLamaKvCacheViewSafeHandle
: SafeLLamaHandleBase
{
private readonly SafeLLamaContextHandle _ctx;
- private LLamaKvCacheView _view;
+ private NativeLLamaKvCacheView _view;
+
+ /// <summary>
+ /// Number of KV cache cells. This will be the same as the context size.
+ /// </summary>
+ public int CellCount => GetNativeView().n_cells;
+
+ /// <summary>
+ /// Get the total number of tokens in the KV cache.
+ ///
+ /// For example, if there are two populated
+ /// cells, the first with 1 sequence id in it and the second with 2 sequence
+ /// ids then you'll have 3 tokens.
+ /// </summary>
+ public int TokenCount => GetNativeView().token_count;
+
+ /// <summary>
+ /// Maximum number of sequences visible for a cell. There may be more sequences than this
+ /// in reality, this is simply the maximum number this view can see.
+ /// </summary>
+ public int MaxSequenceCount => GetNativeView().n_seq_max;
+
+ /// <summary>
+ /// Number of populated cache cells
+ /// </summary>
+ public int UsedCellCount => GetNativeView().used_cells;
+
+ /// <summary>
+ /// Maximum contiguous empty slots in the cache.
+ /// </summary>
+ public int MaxContiguous => GetNativeView().max_contiguous;
+
+ /// <summary>
+ /// Index to the start of the MaxContiguous slot range. Can be negative when cache is full.
+ /// </summary>
+ public int MaxContiguousIdx => GetNativeView().max_contiguous_idx;
///
/// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
///
///
///
- public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
+ private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
: base((IntPtr)1, true)
{
_ctx = ctx;
@@ -81,76 +67,176 @@ public class LLamaKvCacheViewSafeHandle
///
public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
{
- var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
- return new LLamaKvCacheViewSafeHandle(ctx, result);
+ // Create the native view and wrap it so that it is freed when the handle is disposed
+ var nativeView = llama_kv_cache_view_init(ctx, maxSequences);
+ var wrapped = new LLamaKvCacheViewSafeHandle(ctx, nativeView);
+
+ // A newly initialised view is empty, so fill it in before handing it to the caller
+ wrapped.Update();
+
+ return wrapped;
}
///
protected override bool ReleaseHandle()
{
- NativeApi.llama_kv_cache_view_free(ref _view);
+ llama_kv_cache_view_free(ref _view);
SetHandle(IntPtr.Zero);
return true;
}
///
- /// Update this view
+ /// Read the current KV cache state into this view.
///
public void Update()
{
- NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
+ llama_kv_cache_view_update(_ctx, ref _view);
}
///
/// Get the raw KV cache view
///
///
- public ref LLamaKvCacheView GetView()
+ private ref NativeLLamaKvCacheView GetNativeView()
{
+ // Use the (objectName, message) overload: the single-string constructor treats its argument as the object name
+ if (IsClosed)
+ throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle), "Cannot access LLamaKvCacheViewSafeHandle after it has been disposed");
return ref _view;
}
-}
-public static partial class NativeApi
-{
+ /// <summary>
+ /// Read the position stored in one cell of the view
+ /// </summary>
+ /// <param name="index">Zero-based cell index, in the range [0, CellCount)</param>
+ /// <returns>The position recorded for the requested cell</returns>
+ /// <exception cref="ArgumentOutOfRangeException">Thrown when index is negative or not less than CellCount</exception>
+ public LLamaPos GetCell(int index)
+ {
+ var native = GetNativeView();
+
+ if (index < 0)
+ throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
+ if (native.n_cells <= index)
+ throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");
+
+ unsafe
+ {
+ return (native.cells + index)->pos;
+ }
+ }
+
+ /// <summary>
+ /// Get all of the sequences assigned to the cell at the given index. The span always contains exactly <see cref="MaxSequenceCount"/>
+ /// entries, even if the cell actually has more than that many sequences; allocate a new view with a larger maxSequences parameter
+ /// if necessary. Invalid sequences will be negative values.
+ /// </summary>
+ /// <param name="index">The index of the cell [0, CellCount)</param>
+ /// <returns>A span over this view's native sequence buffer, covering the entries for the requested cell</returns>
+ /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
+ public Span<LLamaSeqId> GetCellSequences(int index)
+ {
+ var view = GetNativeView();
+
+ if (index < 0)
+ throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
+ if (index >= view.n_cells)
+ throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");
+
+ unsafe
+ {
+ return new Span<LLamaSeqId>(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max);
+ }
+ }
+
+ #region native API
///
/// Create an empty KV cache view. (use only for debugging purposes)
///
///
///
///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);
-
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);
+
///
/// Free a KV cache view. (use only for debugging purposes)
///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);
-
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);
+
///
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
///
///
///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);
-
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
+
///
- /// Returns the number of tokens in the KV cache (slow, use only for debug)
- /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+ /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
-
+ [StructLayout(LayoutKind.Sequential)]
+ private struct NativeLLamaKvCacheViewCell
+ {
+ ///
+ /// The position for this cell. Takes KV cache shifts into account.
+ /// May be negative if the cell is not populated.
+ ///
+ public LLamaPos pos;
+ }
+
///
- /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+ /// An updateable view of the KV cache (llama_kv_cache_view)
///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
+ [StructLayout(LayoutKind.Sequential)]
+ private unsafe struct NativeLLamaKvCacheView
+ {
+ ///
+ /// Number of KV cache cells. This will be the same as the context size.
+ ///
+ public int n_cells;
+
+ ///
+ /// Maximum number of sequences that can exist in a cell. It's not an error
+ /// if there are more sequences in a cell than this value, however they will
+ /// not be visible in the view cells_sequences.
+ ///
+ public int n_seq_max;
+
+ ///
+ /// Number of tokens in the cache. For example, if there are two populated
+ /// cells, the first with 1 sequence id in it and the second with 2 sequence
+ /// ids then you'll have 3 tokens.
+ ///
+ public int token_count;
+
+ ///
+ /// Number of populated cache cells.
+ ///
+ public int used_cells;
+
+ ///
+ /// Maximum contiguous empty slots in the cache.
+ ///
+ public int max_contiguous;
+
+ ///
+ /// Index to the start of the max_contiguous slot range. Can be negative
+ /// when cache is full.
+ ///
+ public int max_contiguous_idx;
+
+ ///
+ /// Information for an individual cell.
+ ///
+ public NativeLLamaKvCacheViewCell* cells;
+
+ ///
+ /// The sequences for each cell. There will be n_seq_max items per cell.
+ ///
+ public LLamaSeqId* cells_sequences;
+ }
+ #endregion
}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 708cdacc..9301198e 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -263,6 +263,23 @@ namespace LLama.Native
{
NativeLogConfig.llama_log_set(logCallback);
}
+
+ /// <summary>
+ /// Returns the number of tokens in the KV cache (slow, use only for debug)
+ /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+ /// </summary>
+ /// <param name="ctx">The context whose KV cache tokens are counted</param>
+ /// <returns>The number of tokens in the KV cache</returns>
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
+
+ /// <summary>
+ /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+ /// </summary>
+ /// <param name="ctx">The context whose KV cache cells are counted</param>
+ /// <returns>The number of used KV cells</returns>
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
///
/// Clear the KV cache