Merge pull request #102 from martindevans/grammar_basics
Grammar basics
This commit is contained in:
commit
976ca6a740
|
@ -1,14 +1,31 @@
|
|||
using System.Text;
|
||||
using LLama.Common;
|
||||
|
||||
namespace LLama.Unittest
|
||||
{
|
||||
public class BasicTest
|
||||
: IDisposable
|
||||
{
|
||||
[Fact]
|
||||
public void LoadModel()
|
||||
private readonly ModelParams _params;
|
||||
private readonly LLamaWeights _model;
|
||||
|
||||
public BasicTest()
|
||||
{
|
||||
var model = new LLamaContext(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256));
|
||||
model.Dispose();
|
||||
_params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048);
|
||||
_model = LLamaWeights.LoadFromFile(_params);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_model.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BasicModelProperties()
|
||||
{
|
||||
Assert.Equal(32000, _model.VocabCount);
|
||||
Assert.Equal(2048, _model.ContextSize);
|
||||
Assert.Equal(4096, _model.EmbeddingSize);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
using System.Text;
|
||||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Unittest
|
||||
{
|
||||
/// <summary>
/// Tests for building a <see cref="SafeLLamaGrammarHandle"/> and for sampling with a grammar attached.
/// </summary>
public sealed class GrammarTest
    : IDisposable
{
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    /// <summary>
    /// Load the model weights used by the tests in this class.
    /// </summary>
    public GrammarTest()
    {
        _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048);
        _model = LLamaWeights.LoadFromFile(_params);
    }

    /// <summary>
    /// Dispose the weights loaded by the constructor.
    /// </summary>
    public void Dispose() => _model.Dispose();

    /// <summary>
    /// Build a one-rule grammar and dispose it again. There are no assertions; the test
    /// only checks that creation completes without throwing.
    /// </summary>
    [Fact]
    public void CreateBasicGrammar()
    {
        // CHAR 'a' followed by CHAR_RNG_UPPER 'z' describes the inclusive range [a-z];
        // the END element terminates the rule.
        var letterRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var rules = new List<List<LLamaGrammarElement>> { letterRule };

        using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
    }

    /// <summary>
    /// Run inference with a grammar that only admits "cat" and check the first returned piece of text.
    /// </summary>
    [Fact]
    public void SampleWithTrivialGrammar()
    {
        // The grammar only allows the output "cat". That is a nonsense answer to the prompt below,
        // so seeing it in the result shows the grammar (and not the model) decided the output.
        var catRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var rules = new List<List<LLamaGrammarElement>> { catRule };

        using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);

        var executor = new StatelessExecutor(_model, _params);
        var inferenceParams = new InferenceParams
        {
            MaxTokens = 3,
            AntiPrompts = new[] { ".", "Input:", "\n" },
            Grammar = grammar,
        };

        var output = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();

        Assert.Equal("cat", output[0]);
    }
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
using LLama.Common;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Web.Common
|
||||
{
|
||||
|
@ -95,5 +96,10 @@ namespace LLama.Web.Common
|
|||
/// consider newlines as a repeatable token (penalize_nl)
|
||||
/// </summary>
|
||||
public bool PenalizeNL { get; set; } = true;
|
||||
}
|
||||
|
||||
/// <summary>
/// A grammar to constrain possible tokens. Defaults to null.
/// </summary>
/// <remarks>
/// Declared nullable because null is this property's default value; a non-nullable
/// declaration initialised with null misstates the contract to callers.
/// </remarks>
public SafeLLamaGrammarHandle? Grammar { get; set; }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using System.Collections.Generic;
|
||||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Abstractions
|
||||
{
|
||||
|
@ -113,5 +114,10 @@ namespace LLama.Abstractions
|
|||
/// consider newlines as a repeatable token (penalize_nl)
|
||||
/// </summary>
|
||||
public bool PenalizeNL { get; set; }
|
||||
|
||||
/// <summary>
/// Optional grammar restricting which tokens may be produced
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
using LLama.Abstractions;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Common
|
||||
{
|
||||
|
@ -96,6 +97,11 @@ namespace LLama.Common
|
|||
/// consider newlines as a repeatable token (penalize_nl)
|
||||
/// </summary>
|
||||
public bool PenalizeNL { get; set; } = true;
|
||||
|
||||
/// <summary>
/// Optional grammar restricting which tokens may be produced; null when not set
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -291,11 +291,19 @@ namespace LLama
|
|||
/// <param name="topP"></param>
|
||||
/// <param name="tfsZ"></param>
|
||||
/// <param name="typicalP"></param>
|
||||
/// <param name="grammar"></param>
|
||||
/// <returns></returns>
|
||||
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
|
||||
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
|
||||
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
|
||||
SafeLLamaGrammarHandle? grammar = null)
|
||||
{
|
||||
llama_token id;
|
||||
|
||||
if (grammar != null)
|
||||
{
|
||||
SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
|
||||
}
|
||||
|
||||
if (temperature <= 0)
|
||||
{
|
||||
// Greedy sampling
|
||||
|
@ -329,6 +337,12 @@ namespace LLama
|
|||
}
|
||||
mirostat_mu = mu;
|
||||
}
|
||||
|
||||
if (grammar != null)
|
||||
{
|
||||
NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
|
|
|
@ -217,7 +217,8 @@ namespace LLama
|
|||
var mu = MirostatMu;
|
||||
var id = Context.Sample(
|
||||
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
|
||||
inferenceParams.Grammar
|
||||
);
|
||||
MirostatMu = mu;
|
||||
|
||||
|
|
|
@ -206,7 +206,8 @@ namespace LLama
|
|||
var mu = MirostatMu;
|
||||
var id = Context.Sample(
|
||||
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
|
||||
inferenceParams.Grammar
|
||||
);
|
||||
MirostatMu = mu;
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ namespace LLama
|
|||
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
|
||||
|
||||
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
|
||||
|
||||
lastTokens.Add(id);
|
||||
|
||||
|
|
|
@ -25,6 +25,21 @@ namespace LLama
|
|||
/// </summary>
|
||||
public Encoding Encoding { get; }
|
||||
|
||||
/// <summary>
/// Number of tokens in the vocabulary of this model, read from the native handle
/// </summary>
public int VocabCount
{
    get { return NativeHandle.VocabCount; }
}
|
||||
|
||||
/// <summary>
/// Number of tokens in the context, read from the native handle
/// </summary>
public int ContextSize
{
    get { return NativeHandle.ContextSize; }
}
|
||||
|
||||
/// <summary>
/// Dimension of the embedding vectors, read from the native handle
/// </summary>
public int EmbeddingSize
{
    get { return NativeHandle.EmbeddingSize; }
}
|
||||
|
||||
internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
|
||||
{
|
||||
_weights = weights;
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native
|
||||
{
|
||||
/// <summary>
/// The kinds of element that can appear in a grammar rule
/// </summary>
public enum LLamaGrammarElementType
{
    /// <summary>
    /// Marks the end of a rule definition
    /// </summary>
    END = 0,

    /// <summary>
    /// Marks the start of an alternate definition for the rule
    /// </summary>
    ALT = 1,

    /// <summary>
    /// Non-terminal element: a reference to another rule
    /// </summary>
    RULE_REF = 2,

    /// <summary>
    /// Terminal element: a single character (code point)
    /// </summary>
    CHAR = 3,

    /// <summary>
    /// Inverted character match ([^a], [^a-b] [^abc])
    /// </summary>
    CHAR_NOT = 4,

    /// <summary>
    /// Turns a preceding CHAR or CHAR_ALT into
    /// an inclusive range ([a-z])
    /// </summary>
    CHAR_RNG_UPPER = 5,

    /// <summary>
    /// Adds an alternate character to a preceding CHAR or
    /// CHAR_RNG_UPPER ([ab], [a-zA])
    /// </summary>
    CHAR_ALT = 6,
}
|
||||
|
||||
/// <summary>
/// A single element of a grammar rule
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaGrammarElement
{
    // Field order is significant: the struct is laid out sequentially.

    /// <summary>
    /// What kind of element this is
    /// </summary>
    public LLamaGrammarElementType Type;

    /// <summary>
    /// A Unicode code point or a rule ID
    /// </summary>
    public uint Value;

    /// <summary>
    /// Create a grammar element from its type and value
    /// </summary>
    /// <param name="type">Stored in <see cref="Type"/></param>
    /// <param name="value">Stored in <see cref="Value"/></param>
    public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
    {
        this.Type = type;
        this.Value = value;
    }
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace LLama.Native
|
||||
{
|
||||
using llama_token = Int32;
|
||||
|
||||
public unsafe partial class NativeApi
{
    /// <summary>
    /// Build a native grammar from a set of grammar rules
    /// </summary>
    /// <param name="rules">Array of pointers, one per rule, each pointing at that rule's first element</param>
    /// <param name="n_rules">Number of entries in <paramref name="rules"/></param>
    /// <param name="start_rule_index">Index of the rule the grammar starts from</param>
    /// <returns>Pointer to the new grammar; release it with <see cref="llama_grammar_free"/></returns>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);

    /// <summary>
    /// Release all memory held by the given native grammar
    /// </summary>
    /// <param name="grammar">Pointer previously returned by <see cref="llama_grammar_init"/></param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_grammar_free(IntPtr grammar);

    /// <summary>
    /// Apply the constraints of a grammar to a set of candidate tokens
    /// </summary>
    /// <param name="ctx">The llama context</param>
    /// <param name="candidates">Candidate tokens, passed by reference</param>
    /// <param name="grammar">Grammar supplying the constraints</param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_sample_grammar(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaGrammarHandle grammar);

    /// <summary>
    /// Accept a sampled token into the grammar
    /// </summary>
    /// <param name="ctx">The llama context</param>
    /// <param name="grammar">Grammar that should accept the token</param>
    /// <param name="token">The token that was sampled</param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token);
}
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using LLama.Exceptions;
|
||||
|
||||
namespace LLama.Native
|
||||
{
|
||||
/// <summary>
/// A safe reference to a `llama_grammar`
/// </summary>
public class SafeLLamaGrammarHandle
    : SafeLLamaHandleBase
{
    #region construction/destruction
    /// <summary>
    /// Wrap a pointer to a native grammar. The pointer is passed to
    /// <c>llama_grammar_free</c> when this handle is released.
    /// </summary>
    /// <param name="handle">Pointer to the native grammar</param>
    internal SafeLLamaGrammarHandle(IntPtr handle)
        : base(handle)
    {
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_grammar_free(handle);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule.
    /// Every rule must be non-empty and end with a <see cref="LLamaGrammarElementType.END"/> element</param>
    /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
    /// <returns>A handle owning the newly created grammar</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rules"/> is null</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> does not refer to a rule</exception>
    /// <exception cref="ArgumentException">A rule is null, empty or not terminated by an END element</exception>
    /// <exception cref="RuntimeError">The native call failed to create a grammar</exception>
    public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
    {
        if (rules == null)
            throw new ArgumentNullException(nameof(rules));

        // This also rejects an empty rule list, since no index is valid for it
        if (start_rule_index >= (ulong)rules.Count)
            throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, "Start rule index must refer to one of the supplied rules");

        // Rules are handed to native code as bare pointers with no length, so the END terminator
        // is the only thing that marks where a rule stops. Check it here, where a failure can
        // still be reported as a managed exception.
        var totalElements = 0;
        for (var i = 0; i < rules.Count; i++)
        {
            var rule = rules[i];
            if (rule == null || rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END)
                throw new ArgumentException($"Rule {i} must be a non-empty list terminated by an END element", nameof(rules));

            totalElements += rule.Count;
        }

        unsafe
        {
            var nrules = (ulong)rules.Count;

            // Borrow an array large enough to hold every single element
            // and another array large enough to hold a pointer to each rule
            var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
            var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
            try
            {
                fixed (LLamaGrammarElement* allElementsPtr = allElements)
                {
                    var elementIndex = 0;
                    var pointerIndex = 0;
                    for (var i = 0; i < rules.Count; i++)
                    {
                        var rule = rules[i];

                        // Save a pointer to the start of this rule
                        pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                        // Copy all of the rule elements into the flat array. Indexing by Count keeps
                        // the number of writes equal to the number of elements that were budgeted for.
                        for (var j = 0; j < rule.Count; j++)
                            allElementsPtr[elementIndex++] = rule[j];
                    }

                    // Sanity check some things that should be true if the copy worked as planned
                    Debug.Assert((ulong)pointerIndex == nrules);
                    Debug.Assert(elementIndex == totalElements);

                    // Make the actual call through to llama.cpp
                    fixed (void* ptr = pointers)
                    {
                        return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                    }
                }
            }
            finally
            {
                ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                ArrayPool<IntPtr>.Shared.Return(pointers);
            }
        }
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
    /// <param name="nrules">total number of rules</param>
    /// <param name="start_rule_index">index of the start rule of the grammar</param>
    /// <returns>A handle owning the newly created grammar</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rules"/> is a null pointer</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> is not less than <paramref name="nrules"/></exception>
    /// <exception cref="RuntimeError">The native call failed to create a grammar</exception>
    public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
    {
        if (rules == null)
            throw new ArgumentNullException(nameof(rules));
        if (start_rule_index >= nrules)
            throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, "Start rule index must be less than the number of rules");

        var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
        if (grammar_ptr == IntPtr.Zero)
            throw new RuntimeError("Failed to create grammar from rules");

        return new(grammar_ptr);
    }
    #endregion
}
|
||||
}
|
|
@ -5,6 +5,18 @@ namespace LLama.Native
|
|||
using llama_token = Int32;
|
||||
public unsafe class SamplingApi
|
||||
{
|
||||
/// <summary>
/// Apply the constraints of a grammar to a set of candidate tokens
/// </summary>
/// <param name="ctx">Context handed through to the native call</param>
/// <param name="candidates">Candidate tokens to constrain</param>
/// <param name="grammar">Grammar supplying the constraints</param>
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
    // Create the native form of the candidates; the returned handle is disposed when this method exits
    using var nativeHandle = LLamaTokenDataArrayNative.Create(candidates, out var nativeCandidates);

    NativeApi.llama_sample_grammar(ctx, ref nativeCandidates, grammar);
}
|
||||
|
||||
/// <summary>
|
||||
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
/// </summary>
|
||||
|
|
Loading…
Reference in New Issue