commit
66d6b00b49
|
@ -7,6 +7,9 @@ using System.Text;
|
|||
using System.IO;
|
||||
using System.IO.MemoryMappedFiles;
|
||||
using LLama.Common;
|
||||
using System.Runtime.InteropServices;
|
||||
using LLama.Extensions;
|
||||
using Microsoft.Win32.SafeHandles;
|
||||
|
||||
namespace LLama
|
||||
{
|
||||
|
@ -117,6 +120,7 @@ namespace LLama
|
|||
/// Get the state data as a byte array.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
|
||||
public byte[] GetStateData()
|
||||
{
|
||||
var stateSize = NativeApi.llama_get_state_size(_ctx);
|
||||
|
@ -125,6 +129,44 @@ namespace LLama
|
|||
return stateMemory;
|
||||
}
|
||||
|
||||
/// <summary>
/// Get the state data as an opaque handle
/// </summary>
/// <returns></returns>
public State GetState()
{
    // Upper bound on the serialized state size; the exact size is only
    // known after the copy has actually been performed.
    var stateSize = NativeApi.llama_get_state_size(_ctx);

    unsafe
    {
        var scratch = Marshal.AllocHGlobal((nint)stateSize);
        var compact = IntPtr.Zero;
        try
        {
            // Serialize into the worst-case sized scratch buffer and learn
            // how many bytes were really written.
            var written = NativeApi.llama_copy_state_data(_ctx, (byte*)scratch);

            // Shrink: move the data into a right-sized buffer, then free the
            // oversized scratch allocation immediately to cap memory usage.
            compact = Marshal.AllocHGlobal((nint)written);
            Buffer.MemoryCopy(scratch.ToPointer(), compact.ToPointer(), written, written);
            Marshal.FreeHGlobal(scratch);
            scratch = IntPtr.Zero;

            // Ownership of the compact buffer transfers to the State handle,
            // which frees it in its ReleaseHandle.
            return new State(compact);
        }
        catch
        {
            // Free whichever buffers this method still owns, then let the
            // original exception propagate.
            if (scratch != IntPtr.Zero)
                Marshal.FreeHGlobal(scratch);
            if (compact != IntPtr.Zero)
                Marshal.FreeHGlobal(compact);
            throw;
        }
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Load the state from specified path.
|
||||
/// </summary>
|
||||
|
@ -161,6 +203,19 @@ namespace LLama
|
|||
NativeApi.llama_set_state_data(_ctx, stateData);
|
||||
}
|
||||
|
||||
/// <summary>
/// Load the state from memory.
/// </summary>
/// <param name="state">Handle previously returned by <see cref="GetState"/>.</param>
/// <exception cref="RuntimeError"></exception>
public void LoadState(State state)
{
    // Hold a reference on the SafeHandle so the unmanaged buffer cannot be
    // released (e.g. by a concurrent Dispose or finalization of `state`)
    // while the native call is still reading from it. Calling
    // DangerousGetHandle() without this ref-count is a use-after-free hazard.
    var addedRef = false;
    try
    {
        state.DangerousAddRef(ref addedRef);
        unsafe
        {
            NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
        }
    }
    finally
    {
        if (addedRef)
            state.DangerousRelease();
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Perform the sampling. Please don't use it unless you fully know what it does.
|
||||
/// </summary>
|
||||
|
@ -304,12 +359,30 @@ namespace LLama
|
|||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
public virtual void Dispose()
{
    _ctx.Dispose();

    // This Dispose is virtual, so derived types (e.g. ResettableLLamaModel)
    // add their own cleanup on top. Per the standard dispose pattern
    // (CA1816), suppress finalization so no finalizer in the hierarchy runs
    // after an explicit Dispose.
    GC.SuppressFinalize(this);
}
|
||||
|
||||
/// <summary>
/// The state of this model, which can be reloaded later
/// </summary>
public class State : SafeHandleZeroOrMinusOneIsInvalid
{
    // Takes ownership of `memory` (allocated via Marshal.AllocHGlobal);
    // ownsHandle: true means ReleaseHandle frees it exactly once.
    internal State(IntPtr memory)
        : base(true)
    {
        SetHandle(memory);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Return the unmanaged buffer captured in the constructor.
        Marshal.FreeHGlobal(handle);
        return true;
    }
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,10 +3,8 @@ using LLama.Common;
|
|||
using LLama.Native;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
|
||||
namespace LLama
|
||||
|
@ -19,7 +17,7 @@ namespace LLama
|
|||
public class StatelessExecutor : ILLamaExecutor
|
||||
{
|
||||
private LLamaModel _model;
|
||||
private byte[] _originalState;
|
||||
private LLamaModel.State _originalState;
|
||||
/// <summary>
|
||||
/// The mode used by the executor when running the inference.
|
||||
/// </summary>
|
||||
|
@ -33,7 +31,7 @@ namespace LLama
|
|||
_model = model;
|
||||
var tokens = model.Tokenize(" ", true);
|
||||
Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, tokens.Count(), 0, _model.Params.Threads);
|
||||
_originalState = model.GetStateData();
|
||||
_originalState = model.GetState();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
|
|
|
@ -13,7 +13,7 @@ namespace LLama
|
|||
/// <summary>
/// The initial state of the model
/// </summary>
// NOTE(review): the public setter lets callers replace this handle without
// disposing the previous one, leaking its unmanaged buffer — consider
// `private set`. Left unchanged here to preserve the public interface.
public State OriginalState { get; set; }
|
||||
/// <summary>
|
||||
///
|
||||
/// </summary>
|
||||
|
@ -21,7 +21,7 @@ namespace LLama
|
|||
/// <param name="Params">Model parameters forwarded to the base <c>LLamaModel</c> constructor.</param>
/// <param name="encoding">Text encoding name forwarded to the base constructor (default "UTF-8").</param>
public ResettableLLamaModel(ModelParams Params, string encoding = "UTF-8") : base(Params, encoding)
{
    // Capture the freshly-constructed model's state so it can be restored
    // later via LoadState(OriginalState); the handle owns unmanaged memory
    // and is released in Dispose().
    OriginalState = GetState();
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -31,5 +31,13 @@ namespace LLama
|
|||
{
|
||||
LoadState(OriginalState);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
public override void Dispose()
{
    // Release the unmanaged buffer held by the cached initial state first
    // (SafeHandle.Dispose is safe to call more than once), then let the
    // base class dispose the native context.
    OriginalState.Dispose();

    base.Dispose();
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue