Beam Search (#155)

* Added the low level bindings to beam search.
This commit is contained in:
Martin Evans 2023-09-07 19:26:51 +01:00 committed by GitHub
parent a09aa86324
commit d3b8ee988c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 181 additions and 2 deletions

View File

@ -52,4 +52,4 @@ jobs:
- name: Build
run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore
- name: Test
run: dotnet test LLamaSharp.sln -c ${{ matrix.config }}
run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed"

View File

@ -0,0 +1,63 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using Xunit.Abstractions;
namespace LLama.Unittest;
/// <summary>
/// Tests for the low-level beam search bindings (<see cref="NativeApi.llama_beam_search"/>).
/// </summary>
public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
        _params = new ModelParams(Constants.ModelPath)
        {
            ContextSize = 2048
        };
        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact(Skip = "Very very slow in CI")]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;

        // Fix: the context is IDisposable and was previously leaked by this test.
        using var context = _model.CreateContext(_params);

        var result = new StringBuilder();

        // Evaluate the prompt once so the beam search starts from a warm context.
        var initial_tokens = context.Tokenize("The cat sat on");
        result.Append(context.DeTokenize(initial_tokens.ToArray()));
        context.Eval(initial_tokens, 0);

        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Log every beam and its (renormalized) cumulative probability.
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];
                var tokens = context.DeTokenize(view.Tokens.ToArray());
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Tokens shared by all beams are about to be shifted out of every beam,
            // so copy them (from any beam, beams[0] here) into the result now.
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
            }
        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {result}");
    }
}

View File

@ -66,7 +66,7 @@ namespace LLama.Unittest
Grammar = grammar,
};
var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();
var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
Assert.Equal("cat", result[0]);
}

View File

@ -0,0 +1,42 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
using llama_token = Int32;
/// <summary>
/// Information about a single beam in a beam search.
/// Mirrors the native <c>llama_beam_view</c> struct; instances are read directly
/// from memory owned by llama.cpp, so field order/size must match the native layout.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
    // Pointer into native memory owned by llama.cpp; only valid for the
    // duration of the beam search callback it was passed to.
    private readonly unsafe llama_token* tokens;
    // Number of tokens addressed by `tokens`.
    private readonly nint n_tokens;
    /// <summary>
    /// Cumulative beam probability (renormalized relative to all beams)
    /// </summary>
    public readonly float CumulativeProbability;
    /// <summary>
    /// Callback should set this to true when a beam is at end-of-beam.
    /// </summary>
    // NOTE(review): C# bool is not blittable and this struct is reinterpreted
    // over native memory — confirm the managed size of this field matches the
    // native C++ bool (1 byte) in llama_beam_view. TODO verify.
    public bool EndOfBeam;
    /// <summary>
    /// Tokens in this beam. This is a view over native memory — do not hold
    /// onto it beyond the callback invocation.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the beam contains more than int.MaxValue tokens (Span length is a 32 bit int).
    /// </exception>
    public readonly Span<llama_token> Tokens
    {
        get
        {
            unsafe
            {
                if (n_tokens > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 tokens is not supported");
                return new Span<llama_token>(tokens, (int)n_tokens);
            }
        }
    }
}

View File

@ -0,0 +1,49 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
/// <summary>
/// Passed to beam_search_callback function.
/// Whenever 0 &lt; common_prefix_length, this number of tokens should be copied from any of the beams
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
/// Mirrors the native <c>llama_beams_state</c> struct; instances are read directly
/// from memory owned by llama.cpp, so field order/size must match the native layout.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct LLamaBeamsState
{
    /// <summary>
    /// The state of each individual beam
    /// </summary>
    // Pointer into native memory owned by llama.cpp; only valid for the
    // duration of the callback this struct was passed to.
    private readonly unsafe LLamaBeamView* beam_views;
    /// <summary>
    /// Number of elements in beam_views
    /// </summary>
    private readonly nint n_beams;
    /// <summary>
    /// Current max length of prefix tokens shared by all beams.
    /// </summary>
    public readonly ulong CommonPrefixLength;
    /// <summary>
    /// True iff this is the last callback invocation.
    /// </summary>
    // NOTE(review): C# bool is not blittable and this struct is reinterpreted
    // over native memory — confirm this field's managed size matches the native
    // C++ bool (1 byte) in llama_beams_state. TODO verify.
    public readonly bool LastCall;
    /// <summary>
    /// The current state of each beam. This is a view over native memory — do
    /// not hold onto it beyond the callback invocation.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if there are more than int.MaxValue beams (Span length is a 32 bit int).
    /// </exception>
    public Span<LLamaBeamView> Beams
    {
        get
        {
            unsafe
            {
                if (n_beams > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 beams is not supported");
                return new Span<LLamaBeamView>(beam_views, (int)n_beams);
            }
        }
    }
}

View File

@ -0,0 +1,25 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
public partial class NativeApi
{
    /// <summary>
    /// Type of pointer to the beam_search_callback function.
    /// </summary>
    /// <param name="callback_data">callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback</param>
    /// <param name="state">The current state of all beams in the search</param>
    // Fix: the native side invokes this callback with the Cdecl convention; without
    // this attribute the default (Winapi/StdCall on x86) would corrupt the stack.
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state);

    /// <summary>Deterministically returns entire sentence constructed by a beam search.</summary>
    /// <param name="ctx">Pointer to the llama_context.</param>
    /// <param name="callback">Invoked for each iteration of the beam_search loop, passing in beams_state.</param>
    /// <param name="callback_data">A pointer that is simply passed back to callback.</param>
    /// <param name="n_beams">Number of beams to use.</param>
    /// <param name="n_past">Number of tokens already evaluated.</param>
    /// <param name="n_predict">Maximum number of tokens to predict. EOS may occur earlier.</param>
    /// <param name="n_threads">Number of threads.</param>
    [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);
}