- Added `SamplingPipeline` to inference params; when set, it replaces all other sampling options with an entirely custom pipeline.
- Added a `Sample` method to `LLamaContext` which uses a custom pipeline.
- Modified all executors to use the custom pipeline if it exists.
This commit is contained in:
parent 33358124db
commit b34f72a883
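In practice this means a caller can swap out the entire built-in sampler chain from `InferenceParams`. Below is a minimal sketch of the idea, assuming `ISamplingPipeline` has only the two members visible in this diff; `GreedySamplingPipeline` is a hypothetical name, not part of the commit:

```csharp
using System;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Wiring a custom pipeline into inference parameters; every executor
// then bypasses the built-in penalty/mirostat/grammar path entirely.
var inferenceParams = new InferenceParams
{
    SamplingPipeline = new GreedySamplingPipeline(),
};

// Hypothetical pipeline: always picks the highest-logit token (greedy
// decoding). It relies only on the two interface members shown in this diff.
public sealed class GreedySamplingPipeline
    : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Pick the index of the largest raw logit.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Greedy decoding keeps no per-sequence state to clear.
    }
}
```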
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
         /// <summary>
         /// A grammar to constrain possible tokens
         /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
 
         /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 
     /// <summary>
@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
             }
         }
 
+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
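Callers holding a `LLamaContext` directly can now drive sampling themselves with this overload. A rough sketch, assuming a constructed context and pipeline (the evaluation step between samples is elided, since its exact signature is not part of this commit):

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Sampling;

// Rough sketch: sample a fixed number of tokens with a caller-chosen
// pipeline. llama_token is an int alias inside the library.
static List<int> SampleMany(LLamaContext context, ISamplingPipeline pipeline, int count)
{
    var tokens = new List<int>();
    for (var i = 0; i < count; i++)
    {
        // Sample reads the current logits from context.NativeHandle and
        // forwards our recent-token history to the pipeline.
        var next = context.Sample(pipeline, tokens.ToArray());
        tokens.Add(next);
        // ...evaluate `next` with the context here before sampling again...
    }
    return tokens;
}
```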
@@ -210,16 +210,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
@@ -189,16 +189,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
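Both executor hunks above make the same change, and the stateless executor below repeats it: when `SamplingPipeline` is set, the executor hands the raw logits and the recent token history straight to the pipeline, and only otherwise runs the existing penalty, mirostat, and grammar path. This is what makes the "all other sampling parameters are ignored" warning on the interface accurate.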
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }
 
                 // Decode this token into text
                 decoder.Add(id);
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 using LLama.Sampling.Logits;
 using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
     /// <summary>
     /// Sample a single token from the given logits
     /// </summary>
-    /// <param name="ctx"></param>
-    /// <param name="logits"></param>
-    /// <param name="lastTokens"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
     int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
 
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
     void Reset();
 }
 
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
+
 /// <summary>
 /// Simple implementation of `ISamplingPipeline`, applies processors in order every time
 /// </summary>
-public sealed class BasicSamplingPipeline
+public sealed class ConfigurableSamplingPipeline
     : ISamplingPipeline
 {
     /// <summary>
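The `List<int>` overload added here lets executors pass their token history without allocating a new array on every sample: on .NET 5 and later, `CollectionsMarshal.AsSpan` exposes the list's backing store as a span with no copy, while older targets copy into a rented `ArrayPool` buffer and return it afterwards. A hypothetical call site (the `pipeline`, `ctx`, and `logits` values are assumed to come from an executor like those above):

```csharp
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

// Sketch only: the extension resolves the List<int> overload, so the
// caller never copies the history into an array by hand.
static int SampleFromHistory(ISamplingPipeline pipeline, SafeLLamaContextHandle ctx,
                             Span<float> logits, List<int> history)
{
    return pipeline.Sample(ctx, logits, history);
}
```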