Merge pull request #102 from martindevans/grammar_basics

Grammar basics
This commit is contained in:
Martin Evans 2023-08-22 14:54:28 +01:00 committed by GitHub
commit 976ca6a740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 383 additions and 9 deletions

View File

@ -1,14 +1,31 @@
using System.Text;
using LLama.Common;
namespace LLama.Unittest
{
public class BasicTest
    : IDisposable
{
    // Parameters used to load the shared test model.
    private readonly ModelParams _params;
    // Model weights, loaded once per test-class instance and freed in Dispose.
    private readonly LLamaWeights _model;

    /// <summary>
    /// Load the test model weights once for all tests in this class.
    /// </summary>
    public BasicTest()
    {
        _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048);
        _model = LLamaWeights.LoadFromFile(_params);
    }

    /// <summary>
    /// Free the native model weights after the tests have run.
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void BasicModelProperties()
    {
        // Known values for the llama-2 7B q3_K_S model used by the test suite.
        Assert.Equal(32000, _model.VocabCount);
        Assert.Equal(2048, _model.ContextSize);
        Assert.Equal(4096, _model.EmbeddingSize);
    }
}
}

View File

@ -0,0 +1,71 @@
using System.Text;
using LLama.Common;
using LLama.Native;
namespace LLama.Unittest
{
public sealed class GrammarTest
    : IDisposable
{
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    /// <summary>
    /// Load the shared test model once for every test in this class.
    /// </summary>
    public GrammarTest()
    {
        _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 2048);
        _model = LLamaWeights.LoadFromFile(_params);
    }

    /// <summary>
    /// Release the native model weights.
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void CreateBasicGrammar()
    {
        // One rule which matches a single character in the range [a-z].
        var lowercaseRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var allRules = new List<List<LLamaGrammarElement>> { lowercaseRule };

        // Creating (and disposing) the handle must not throw.
        using var handle = SafeLLamaGrammarHandle.Create(allRules, 0);
    }

    [Fact]
    public void SampleWithTrivialGrammar()
    {
        // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
        // we can be confident it's not what the LLM would say if not constrained by the grammar!
        var catRule = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };
        var allRules = new List<List<LLamaGrammarElement>> { catRule };

        using var grammar = SafeLLamaGrammarHandle.Create(allRules, 0);

        var executor = new StatelessExecutor(_model, _params);
        var inferenceParams = new InferenceParams
        {
            MaxTokens = 3,
            AntiPrompts = new [] { ".", "Input:", "\n" },
            Grammar = grammar,
        };

        var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();

        Assert.Equal("cat", result[0]);
    }
}
}

View File

@ -1,5 +1,6 @@
using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
namespace LLama.Web.Common
{
@ -95,5 +96,10 @@ namespace LLama.Web.Common
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; } = true;
}
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
public SafeLLamaGrammarHandle Grammar { get; set; } = null;
}
}

View File

@ -1,5 +1,6 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
namespace LLama.Abstractions
{
@ -113,5 +114,10 @@ namespace LLama.Abstractions
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; }
/// <summary>
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }
}
}

View File

@ -1,6 +1,7 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using LLama.Native;
namespace LLama.Common
{
@ -96,6 +97,11 @@ namespace LLama.Common
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
public bool PenalizeNL { get; set; } = true;
/// <summary>
/// A grammar to constrain the possible tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
}
/// <summary>

View File

@ -291,11 +291,19 @@ namespace LLama
/// <param name="topP"></param>
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
SafeLLamaGrammarHandle? grammar = null)
{
llama_token id;
if (grammar != null)
{
SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
@ -329,6 +337,12 @@ namespace LLama
}
mirostat_mu = mu;
}
if (grammar != null)
{
NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
}
return id;
}

View File

@ -217,7 +217,8 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
);
MirostatMu = mu;

View File

@ -206,7 +206,8 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
);
MirostatMu = mu;

View File

@ -101,7 +101,7 @@ namespace LLama
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
lastTokens.Add(id);

View File

@ -25,6 +25,21 @@ namespace LLama
/// </summary>
public Encoding Encoding { get; }
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount => NativeHandle.VocabCount;
/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => NativeHandle.ContextSize;
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => NativeHandle.EmbeddingSize;
internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
{
_weights = weights;

View File

@ -0,0 +1,75 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native
{
/// <summary>
/// grammar element type
/// </summary>
public enum LLamaGrammarElementType
{
    /// <summary>
    /// end of rule definition
    /// </summary>
    END = 0,

    /// <summary>
    /// start of alternate definition for rule
    /// </summary>
    ALT = 1,

    /// <summary>
    /// non-terminal element: reference to rule
    /// </summary>
    RULE_REF = 2,

    /// <summary>
    /// terminal element: character (code point)
    /// </summary>
    CHAR = 3,

    /// <summary>
    /// inverse char(s) ([^a], [^a-b] [^abc])
    /// </summary>
    CHAR_NOT = 4,

    /// <summary>
    /// modifies a preceding CHAR or CHAR_ALT to
    /// be an inclusive range ([a-z])
    /// </summary>
    CHAR_RNG_UPPER = 5,

    /// <summary>
    /// modifies a preceding CHAR or
    /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    /// </summary>
    CHAR_ALT = 6,
};

/// <summary>
/// An element of a grammar
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaGrammarElement
    : IEquatable<LLamaGrammarElement>
{
    /// <summary>
    /// The type of this element
    /// </summary>
    public LLamaGrammarElementType Type;

    /// <summary>
    /// Unicode code point or rule ID
    /// </summary>
    public uint Value;

    /// <summary>
    /// Construct a new LLamaGrammarElement
    /// </summary>
    /// <param name="type">The type of this element</param>
    /// <param name="value">Unicode code point or rule ID</param>
    public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
    {
        Type = type;
        Value = value;
    }

    /// <summary>
    /// Check if this element is equal to another (same type and same value)
    /// </summary>
    /// <param name="other"></param>
    /// <returns></returns>
    public bool Equals(LLamaGrammarElement other)
    {
        return Type == other.Type && Value == other.Value;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // Value equality only against other grammar elements
        return obj is LLamaGrammarElement other && Equals(other);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            // Simple field-combining hash, consistent with Equals
            return ((int)Type * 397) ^ (int)Value;
        }
    }
}
}

View File

@ -0,0 +1,45 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native
{
using llama_token = Int32;
/// <summary>
/// P/Invoke bindings for the llama.cpp grammar functions.
/// </summary>
public unsafe partial class NativeApi
{
    /// <summary>
    /// Create a new grammar from the given set of grammar rules
    /// </summary>
    /// <param name="rules">pointer to an array of rule pointers, one per rule; each rule is a sequence of
    /// <see cref="LLamaGrammarElement"/> (presumably terminated by an END element — confirm against llama.cpp)</param>
    /// <param name="n_rules">number of entries in <paramref name="rules"/></param>
    /// <param name="start_rule_index">index of the start (root) rule within <paramref name="rules"/></param>
    /// <returns>raw pointer to the native grammar; callers in this library treat IntPtr.Zero as failure</returns>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);

    /// <summary>
    /// Free all memory from the given SafeLLamaGrammarHandle
    /// </summary>
    /// <param name="grammar">raw pointer previously returned by <see cref="llama_grammar_init"/></param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_grammar_free(IntPtr grammar);

    /// <summary>
    /// Apply constraints from grammar
    /// </summary>
    /// <param name="ctx">the context</param>
    /// <param name="candidates">candidate tokens; passed by ref so the native side can modify them</param>
    /// <param name="grammar">grammar to apply</param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_sample_grammar(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaGrammarHandle grammar);

    /// <summary>
    /// Accepts the sampled token into the grammar
    /// </summary>
    /// <param name="ctx">the context</param>
    /// <param name="grammar">grammar which should advance past the accepted token</param>
    /// <param name="token">the sampled token id</param>
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token);
}
}

View File

@ -0,0 +1,105 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using LLama.Exceptions;
namespace LLama.Native
{
/// <summary>
/// A safe reference to a `llama_grammar`
/// </summary>
/// <summary>
/// A safe reference to a `llama_grammar`
/// </summary>
public class SafeLLamaGrammarHandle
    : SafeLLamaHandleBase
{
    #region construction/destruction
    /// <summary>
    /// Wrap a raw native grammar pointer (as returned by llama_grammar_init)
    /// </summary>
    /// <param name="handle">raw pointer to the native grammar</param>
    internal SafeLLamaGrammarHandle(IntPtr handle)
        : base(handle)
    {
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Free the native grammar, then clear the handle so it cannot be double-freed
        NativeApi.llama_grammar_free(handle);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
    /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
    /// <returns></returns>
    /// <exception cref="RuntimeError"></exception>
    public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
    {
        unsafe
        {
            var totalElements = rules.Sum(a => a.Count);
            var nrules = (ulong)rules.Count;

            // Borrow an array large enough to hold every single element
            // and another array large enough to hold a pointer to each rule
            // (rented arrays may be larger than requested; only the first
            // totalElements / rules.Count entries are used)
            var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
            var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
            try
            {
                // Pin the element array so pointers into it stay valid for the native call
                fixed (LLamaGrammarElement* allElementsPtr = allElements)
                {
                    var elementIndex = 0;
                    var pointerIndex = 0;
                    foreach (var rule in rules)
                    {
                        // Save a pointer to the start of this rule
                        pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                        // Copy all of the rule elements into the flat array
                        foreach (var element in rule)
                            allElementsPtr[elementIndex++] = element;
                    }

                    // Sanity check some things that should be true if the copy worked as planned
                    Debug.Assert((ulong)pointerIndex == nrules);
                    Debug.Assert(elementIndex == totalElements);

                    // Make the actual call through to llama.cpp
                    // (the pointer array must also be pinned: its contents point into allElements)
                    fixed (void* ptr = pointers)
                    {
                        return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                    }
                }
            }
            finally
            {
                // Always hand the rented buffers back, even if the native call throws
                ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                ArrayPool<IntPtr>.Shared.Return(pointers);
            }
        }
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
    /// <param name="nrules">total number of rules</param>
    /// <param name="start_rule_index">index of the start rule of the grammar</param>
    /// <returns></returns>
    /// <exception cref="RuntimeError"></exception>
    public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
    {
        // A null pointer from the native side indicates the rules were rejected
        var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
        if (grammar_ptr == IntPtr.Zero)
            throw new RuntimeError("Failed to create grammar from rules");

        return new(grammar_ptr);
    }
    #endregion
}
}

View File

@ -5,6 +5,18 @@ namespace LLama.Native
using llama_token = Int32;
public unsafe class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx">the context</param>
/// <param name="candidates">candidate tokens to constrain with the grammar</param>
/// <param name="grammar">grammar to apply</param>
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
    // Create a native view over the managed candidates; the handle keeps it valid for the call
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    // NOTE(review): any changes the native call makes to scalar fields of `st` (e.g. a sorted
    // flag) are not copied back into `candidates` after the call — confirm this is intended.
    NativeApi.llama_sample_grammar(ctx, ref st, grammar);
}
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>