Added a safe handle for LLamaKvCacheView
This commit is contained in:
parent
439d14a061
commit
bab6b65b61
|
@ -7,7 +7,11 @@ Console.WriteLine(" __ __ ____ _
|
|||
|
||||
Console.WriteLine("======================================================================================================");
|
||||
|
||||
NativeLibraryConfig.Instance.WithCuda().WithLogs();
|
||||
NativeLibraryConfig
|
||||
.Instance
|
||||
.WithCuda()
|
||||
.WithLogs()
|
||||
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
|
||||
|
||||
NativeApi.llama_empty_call();
|
||||
Console.WriteLine();
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using System.Runtime.InteropServices;
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
|
@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell
|
|||
/// <summary>
|
||||
/// An updateable view of the KV cache (llama_kv_cache_view)
|
||||
/// </summary>
|
||||
//todo: rewrite to safe handle?
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
public unsafe struct LLamaKvCacheView
|
||||
{
|
||||
|
@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView
|
|||
LLamaSeqId* cells_sequences;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A safe handle for a LLamaKvCacheView
|
||||
/// </summary>
|
||||
public class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    private readonly SafeLLamaContextHandle _ctx;
    private LLamaKvCacheView _view;

    /// <summary>
    /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
    /// </summary>
    /// <param name="ctx">The context passed to the native calls made by this handle</param>
    /// <param name="view">The view to wrap (e.g. the result of `llama_kv_cache_view_init`); it is freed when this handle is released</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="ctx"/> is null</exception>
    public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
        // The view struct itself is the resource, there is no single native pointer to store.
        // A non-zero placeholder is used as the handle value (presumably so that the base class
        // does not treat the handle as invalid - confirm against SafeLLamaHandleBase).
        : base(IntPtr.MaxValue, true)
    {
        _ctx = ctx ?? throw new ArgumentNullException(nameof(ctx));
        _view = view;
    }

    /// <summary>
    /// Allocate a new llama_kv_cache_view and wrap it in a safe handle
    /// </summary>
    /// <param name="ctx">The context to create the view for</param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
    /// <returns>A handle which owns the newly created view</returns>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
        return new LLamaKvCacheViewSafeHandle(ctx, result);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_kv_cache_view_free(ref _view);
        SetHandle(IntPtr.Zero);

        return true;
    }

    /// <summary>
    /// Update this view
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has already been closed</exception>
    public void Update()
    {
        ThrowIfClosed();
        NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
    }

    /// <summary>
    /// Count the number of used cells in the KV cache
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_used_cells` for the context</returns>
    public int CountCells()
    {
        return NativeApi.llama_get_kv_cache_used_cells(_ctx);
    }

    /// <summary>
    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_token_count` for the context</returns>
    public int CountTokens()
    {
        return NativeApi.llama_get_kv_cache_token_count(_ctx);
    }

    /// <summary>
    /// Get the raw KV cache view
    /// </summary>
    /// <returns>A reference to the view owned by this handle. It must not be used after this handle has been disposed</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has already been closed</exception>
    public ref LLamaKvCacheView GetView()
    {
        ThrowIfClosed();
        return ref _view;
    }

    /// <summary>
    /// Throw if the view has already been handed to `llama_kv_cache_view_free`, so that callers
    /// never operate on a freed view
    /// </summary>
    private void ThrowIfClosed()
    {
        if (IsClosed)
            throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle));
    }
}
|
||||
|
||||
partial class NativeApi
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -66,9 +144,8 @@ partial class NativeApi
|
|||
/// <summary>
|
||||
/// Free a KV cache view. (use only for debugging purposes)
|
||||
/// </summary>
|
||||
/// <param name="view"></param>
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view);
|
||||
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);
|
||||
|
||||
/// <summary>
|
||||
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
|
@ -76,7 +153,7 @@ partial class NativeApi
|
|||
/// <param name="ctx"></param>
|
||||
/// <param name="view"></param>
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view);
|
||||
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);
|
||||
|
||||
/// <summary>
|
||||
/// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
|
|
Loading…
Reference in New Issue