From d3b8ee988cd538e1b76347e6644dd10b86b078d8 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 7 Sep 2023 19:26:51 +0100 Subject: [PATCH] Beam Search (#155) * Added the low level bindings to beam search. --- .github/workflows/main.yml | 2 +- LLama.Unittest/BeamTests.cs | 63 ++++++++++++++++++++++++++++ LLama.Unittest/GrammarTest.cs | 2 +- LLama/Native/LLamaBeamView.cs | 42 +++++++++++++++++++ LLama/Native/LLamaBeamsState.cs | 49 ++++++++++++++++++++++ LLama/Native/NativeApi.BeamSearch.cs | 25 +++++++++++ 6 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 LLama.Unittest/BeamTests.cs create mode 100644 LLama/Native/LLamaBeamView.cs create mode 100644 LLama/Native/LLamaBeamsState.cs create mode 100644 LLama/Native/NativeApi.BeamSearch.cs diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4322a0a1..97760307 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -52,4 +52,4 @@ jobs: - name: Build run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore - name: Test - run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} + run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed" diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs new file mode 100644 index 00000000..f8d5cf01 --- /dev/null +++ b/LLama.Unittest/BeamTests.cs @@ -0,0 +1,63 @@ +using System.Text; +using LLama.Common; +using LLama.Native; +using Xunit.Abstractions; + +namespace LLama.Unittest; + +public sealed class BeamTests + : IDisposable +{ + private readonly ITestOutputHelper _testOutputHelper; + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public BeamTests(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + _params = new ModelParams(Constants.ModelPath) + { + ContextSize = 2048 + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + + [Fact(Skip = 
"Very very slow in CI")] + public void BasicBeam() + { + const int num_beams = 2; + const int n_predict = 3; + + var context = _model.CreateContext(_params); + + var result = new StringBuilder(); + + var initial_tokens = context.Tokenize("The cat sat on"); + result.Append(context.DeTokenize(initial_tokens.ToArray())); + context.Eval(initial_tokens, 0); + + NativeApi.llama_beam_search(context.NativeHandle, (data, state) => + { + for (var i = 0; i < state.Beams.Length; i++) + { + ref var view = ref state.Beams[i]; + var tokens = context.DeTokenize(view.Tokens.ToArray()); + _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); + } + + if (state.CommonPrefixLength > 0) + { + var view = state.Beams[0]; + result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray())); + } + + }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); + + _testOutputHelper.WriteLine($"Final: {result}"); + } +} \ No newline at end of file diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index 152ede93..b86a0f40 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -66,7 +66,7 @@ namespace LLama.Unittest Grammar = grammar, }; - var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); + var result = executor.Infer("Q. 7 + 12\nA. 
", inferenceParams).ToList(); Assert.Equal("cat", result[0]); } diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs new file mode 100644 index 00000000..e6a6c39f --- /dev/null +++ b/LLama/Native/LLamaBeamView.cs @@ -0,0 +1,42 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +using llama_token = Int32; + +/// +/// Information about a single beam in a beam search +/// +[StructLayout(LayoutKind.Sequential)] +public struct LLamaBeamView +{ + private readonly unsafe llama_token* tokens; + private readonly nint n_tokens; + + /// + /// Cumulative beam probability (renormalized relative to all beams) + /// + public readonly float CumulativeProbability; + + /// + /// Callback should set this to true when a beam is at end-of-beam. + /// + public bool EndOfBeam; + + /// + /// Tokens in this beam + /// + public readonly Span Tokens + { + get + { + unsafe + { + if (n_tokens > int.MaxValue) + throw new InvalidOperationException("More than 2147483647 tokens is not supported"); + return new Span(tokens, (int)n_tokens); + } + } + } +} \ No newline at end of file diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs new file mode 100644 index 00000000..6f0a447d --- /dev/null +++ b/LLama/Native/LLamaBeamsState.cs @@ -0,0 +1,49 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Passed to beam_search_callback function. +/// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams +/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. +/// +[StructLayout(LayoutKind.Sequential)] +public readonly struct LLamaBeamsState +{ + /// + /// The state of each individual beam + /// + private readonly unsafe LLamaBeamView* beam_views; + + /// + /// Number of elements in beam_views + /// + private readonly nint n_beams; + + /// + /// Current max length of prefix tokens shared by all beams. 
+ /// + public readonly ulong CommonPrefixLength; + + /// + /// True iff this is the last callback invocation. + /// + public readonly bool LastCall; + + /// + /// The current state of each beam + /// + public Span Beams + { + get + { + unsafe + { + if (n_beams > int.MaxValue) + throw new InvalidOperationException("More than 2147483647 beams is not supported"); + return new Span(beam_views, (int)n_beams); + } + } + } +} \ No newline at end of file diff --git a/LLama/Native/NativeApi.BeamSearch.cs b/LLama/Native/NativeApi.BeamSearch.cs new file mode 100644 index 00000000..1049dbe3 --- /dev/null +++ b/LLama/Native/NativeApi.BeamSearch.cs @@ -0,0 +1,25 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +public partial class NativeApi +{ + /// + /// Type of pointer to the beam_search_callback function. + /// + /// callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback + /// + public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state); + + /// Deterministically returns entire sentence constructed by a beam search. + /// Pointer to the llama_context. + /// Invoked for each iteration of the beam_search loop, passing in beams_state. + /// A pointer that is simply passed back to callback. + /// Number of beams to use. + /// Number of tokens already evaluated. + /// Maximum number of tokens to predict. EOS may occur earlier. + /// Number of threads. + [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);