- Added `SamplingPipeline` to inference params; when set, it overrides all other sampling options with an entirely custom pipeline (a usage sketch follows the file summary below).

- Added a `Sample` method to `LLamaContext` which uses a custom pipeline
- Modified all executors to use the custom pipeline if one is set
Martin Evans 2023-12-08 01:02:27 +00:00
parent 33358124db
commit b34f72a883
8 changed files with 121 additions and 33 deletions
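A minimal usage sketch of the new option (the `GreedySamplingPipeline` class below is a hypothetical example, not part of this commit, and it assumes `Sample` and `Reset` are the only members `ISamplingPipeline` requires at this revision):

using System;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// When SamplingPipeline is set, the executors skip the built-in
// penalty/mirostat/top-k path entirely and call the pipeline instead.
var inferenceParams = new InferenceParams
{
    SamplingPipeline = new GreedySamplingPipeline(),
};

// Hypothetical custom pipeline: always picks the highest-logit token,
// ignoring every other sampling parameter.
public sealed class GreedySamplingPipeline : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Stateless, so there is nothing to reset.
    }
}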

View File

@@ -1,6 +1,9 @@
using LLama.Common;
#nullable enable
using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
using LLama.Sampling;
namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
public SafeLLamaGrammarHandle Grammar { get; set; } = null;
public SafeLLamaGrammarHandle? Grammar { get; set; }
/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}
}

View File

@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }
/// <summary>
/// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
}
}

View File

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
namespace LLama.Common
{
@@ -76,6 +77,9 @@ namespace LLama.Common
/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }
/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}
/// <summary>

View File

@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
}
}
/// <summary>
/// Sample a single token from this context, using the given sampling pipeline
/// </summary>
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}
/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>

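A hedged usage sketch for the new `LLamaContext.Sample` overload (the `context`, `pipeline` and `lastTokens` variables are assumed to already be in scope):

// Process the current logits with the custom pipeline and return the chosen token.
var token = context.Sample(pipeline, lastTokens);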
View File

@@ -210,16 +210,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var mu = MirostatMu;
var id = Context.Sample(
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}
_last_n_tokens.Enqueue(id);

View File

@@ -189,16 +189,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var mu = MirostatMu;
var id = Context.Sample(
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}
_last_n_tokens.Enqueue(id);

View File

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -84,17 +85,25 @@ namespace LLama
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
var id = Context.Sample(
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
}
// Decode this token into text
decoder.Add(id);

View File

@@ -1,5 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="ctx"></param>
/// <param name="logits"></param>
/// <param name="lastTokens"></param>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns>
int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
void Reset();
}
/// <summary>
/// Extensions methods for ISamplingPipeline
/// </summary>
public static class ISamplingPipelineExtensions
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="pipeline"></param>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(lastTokens);
return pipeline.Sample(ctx, logits, span);
#else
var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
try
{
lastTokens.CopyTo(copy);
return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
}
finally
{
ArrayPool<int>.Shared.Return(copy);
}
#endif
}
}
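A short sketch of calling the extension overload above with the recent tokens held in a `List<int>` (the `pipeline`, `ctx`, `logits` and `lastTokensList` names are assumptions):

// On .NET 5+ the list is viewed as a span without copying; on older targets a
// temporary array is rented from the shared pool.
var id = pipeline.Sample(ctx, logits, lastTokensList);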
/// <summary>
/// Simple implementation of `ISamplingPipeline`, which applies its processors in order on every call
/// </summary>
public sealed class BasicSamplingPipeline
public sealed class ConfigurableSamplingPipeline
: ISamplingPipeline
{
/// <summary>