Merge pull request #57 from martindevans/larger_states

Larger states
This commit is contained in:
Rinne 2023-07-24 23:10:39 +08:00 committed by GitHub
commit 66d6b00b49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 10 deletions

View File

@ -7,6 +7,9 @@ using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using Microsoft.Win32.SafeHandles;
namespace LLama
{
@ -117,6 +120,7 @@ namespace LLama
/// Get the state data as a byte array.
/// </summary>
/// <returns></returns>
[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
public byte[] GetStateData()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
@ -125,6 +129,44 @@ namespace LLama
return stateMemory;
}
/// <summary>
/// Get the state data as an opaque handle
/// </summary>
/// <returns></returns>
public State GetState()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
unsafe
{
var bigMemory = Marshal.AllocHGlobal((nint)stateSize);
var smallMemory = IntPtr.Zero;
try
{
// Copy the state data into "big memory", discover the actual size required
var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory);
// Allocate a smaller buffer
smallMemory = Marshal.AllocHGlobal((nint)actualSize);
// Copy into the smaller buffer and free the large one to save excess memory usage
Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
Marshal.FreeHGlobal(bigMemory);
bigMemory = IntPtr.Zero;
return new State(smallMemory);
}
catch
{
if (bigMemory != IntPtr.Zero)
Marshal.FreeHGlobal(bigMemory);
if (smallMemory != IntPtr.Zero)
Marshal.FreeHGlobal(smallMemory);
throw;
}
}
}
/// <summary>
/// Load the state from specified path.
/// </summary>
@ -161,6 +203,19 @@ namespace LLama
NativeApi.llama_set_state_data(_ctx, stateData);
}
/// <summary>
/// Load the state from memory.
/// </summary>
/// <param name="state"></param>
/// <exception cref="RuntimeError"></exception>
public void LoadState(State state)
{
unsafe
{
NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
}
}
/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
@ -304,12 +359,30 @@ namespace LLama
}
}
/// <summary>
///
/// </summary>
public void Dispose()
/// <inheritdoc />
public virtual void Dispose()
{
_ctx.Dispose();
}
/// <summary>
/// The state of this model, which can be reloaded later
/// </summary>
public class State
: SafeHandleZeroOrMinusOneIsInvalid
{
internal State(IntPtr memory)
: base(true)
{
SetHandle(memory);
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
Marshal.FreeHGlobal(handle);
return true;
}
}
}
}

View File

@ -3,10 +3,8 @@ using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
namespace LLama
@ -19,7 +17,7 @@ namespace LLama
public class StatelessExecutor : ILLamaExecutor
{
private LLamaModel _model;
private byte[] _originalState;
private LLamaModel.State _originalState;
/// <summary>
/// The mode used by the executor when running the inference.
/// </summary>
@ -33,7 +31,7 @@ namespace LLama
_model = model;
var tokens = model.Tokenize(" ", true);
Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, tokens.Count(), 0, _model.Params.Threads);
_originalState = model.GetStateData();
_originalState = model.GetState();
}
/// <inheritdoc />

View File

@ -13,7 +13,7 @@ namespace LLama
/// <summary>
/// The initial state of the model
/// </summary>
public byte[] OriginalState { get; set; }
public State OriginalState { get; set; }
/// <summary>
///
/// </summary>
@ -21,7 +21,7 @@ namespace LLama
/// <param name="encoding"></param>
public ResettableLLamaModel(ModelParams Params, string encoding = "UTF-8") : base(Params, encoding)
{
OriginalState = GetStateData();
OriginalState = GetState();
}
/// <summary>
@ -31,5 +31,13 @@ namespace LLama
{
LoadState(OriginalState);
}
/// <inheritdoc />
public override void Dispose()
{
OriginalState.Dispose();
base.Dispose();
}
}
}