parent
a09aa86324
commit
d3b8ee988c
|
@ -52,4 +52,4 @@ jobs:
|
|||
- name: Build
|
||||
run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore
|
||||
- name: Test
|
||||
run: dotnet test LLamaSharp.sln -c ${{ matrix.config }}
|
||||
run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed"
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
using System.Text;
|
||||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace LLama.Unittest;
|
||||
|
||||
/// <summary>
/// Tests for the native beam-search API exposed through <see cref="NativeApi.llama_beam_search"/>.
/// Loads model weights once per test-class instance and disposes them in <see cref="Dispose"/>.
/// </summary>
public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
        _params = new ModelParams(Constants.ModelPath)
        {
            ContextSize = 2048
        };
        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact(Skip = "Very very slow in CI")]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;

        // Fix: the context is IDisposable and was previously leaked; dispose it
        // deterministically when the test method exits.
        using var context = _model.CreateContext(_params);

        var result = new StringBuilder();

        // Seed the context with a short prompt and record its text form.
        var initial_tokens = context.Tokenize("The cat sat on");
        result.Append(context.DeTokenize(initial_tokens.ToArray()));
        context.Eval(initial_tokens, 0);

        // The callback is invoked synchronously for each beam-search iteration,
        // so capturing locals (result, context) here is safe.
        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Log every beam's cumulative probability and decoded text.
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];
                var tokens = context.DeTokenize(view.Tokens.ToArray());
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Tokens shared by all beams are shifted out after this callback,
            // so copy them from any beam (beams[0]) into the accumulated result now.
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
            }

        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {result}");
    }
}
|
|
@ -66,7 +66,7 @@ namespace LLama.Unittest
|
|||
Grammar = grammar,
|
||||
};
|
||||
|
||||
var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();
|
||||
var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
|
||||
|
||||
Assert.Equal("cat", result[0]);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
using llama_token = Int32;
|
||||
|
||||
/// <summary>
/// Information about a single beam in a beam search
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
    // NOTE(review): field order mirrors the native llama_beam_view struct — do not reorder.
    private readonly unsafe llama_token* tokens;
    private readonly nint n_tokens;

    /// <summary>
    /// Cumulative beam probability (renormalized relative to all beams)
    /// </summary>
    public readonly float CumulativeProbability;

    /// <summary>
    /// Callback should set this to true when a beam is at end-of-beam.
    /// </summary>
    public bool EndOfBeam;

    /// <summary>
    /// Tokens in this beam
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if the native token count exceeds int.MaxValue.</exception>
    public readonly Span<llama_token> Tokens
    {
        get
        {
            // Span<T> lengths are 32-bit, so guard against a native count that cannot fit.
            var count = n_tokens;
            if (count > int.MaxValue)
                throw new InvalidOperationException("More than 2147483647 tokens is not supported");

            unsafe
            {
                return new Span<llama_token>(tokens, (int)count);
            }
        }
    }
}
|
|
@ -0,0 +1,49 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
/// <summary>
/// Passed to beam_search_callback function.
/// Whenever 0 &lt; common_prefix_length, this number of tokens should be copied from any of the beams
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct LLamaBeamsState
{
    // NOTE(review): field order mirrors the native llama_beams_state struct — do not reorder.

    /// <summary>
    /// The state of each individual beam
    /// </summary>
    private readonly unsafe LLamaBeamView* beam_views;

    /// <summary>
    /// Number of elements in beam_views
    /// </summary>
    private readonly nint n_beams;

    /// <summary>
    /// Current max length of prefix tokens shared by all beams.
    /// </summary>
    public readonly ulong CommonPrefixLength;

    /// <summary>
    /// True iff this is the last callback invocation.
    /// </summary>
    public readonly bool LastCall;

    /// <summary>
    /// The current state of each beam
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if the native beam count exceeds int.MaxValue.</exception>
    public Span<LLamaBeamView> Beams
    {
        get
        {
            // Span<T> lengths are 32-bit, so guard against a native count that cannot fit.
            var count = n_beams;
            if (count > int.MaxValue)
                throw new InvalidOperationException("More than 2147483647 beams is not supported");

            unsafe
            {
                return new Span<LLamaBeamView>(beam_views, (int)count);
            }
        }
    }
}
|
|
@ -0,0 +1,25 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
public partial class NativeApi
{
    /// <summary>
    /// Type of pointer to the beam_search_callback function.
    /// </summary>
    /// <param name="callback_data">callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback</param>
    /// <param name="state">The current state of all beams; see <see cref="LLamaBeamsState"/>.</param>
    // Fix: the native side invokes this callback with the cdecl calling convention.
    // Without this attribute the runtime default (Winapi -> StdCall on x86) is used,
    // which corrupts the stack on 32-bit platforms.
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state);

    /// <summary>Deterministically returns entire sentence constructed by a beam search.</summary>
    /// <param name="ctx">Pointer to the llama_context.</param>
    /// <param name="callback">Invoked for each iteration of the beam_search loop, passing in beams_state.</param>
    /// <param name="callback_data">A pointer that is simply passed back to callback.</param>
    /// <param name="n_beams">Number of beams to use.</param>
    /// <param name="n_past">Number of tokens already evaluated.</param>
    /// <param name="n_predict">Maximum number of tokens to predict. EOS may occur earlier.</param>
    /// <param name="n_threads">Number of threads.</param>
    [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);
}
|
Loading…
Reference in New Issue