Classifier Free Guidance (#536)

* Added a `Guidance` method to `LLamaTokenDataArray` which applies classifier free guidance

* Factored out a safer `llama_sample_apply_guidance` method based on spans

* Created a guided sampling demo using the batched executor

* Fixed a comment: "classifier free", not "context free"

* Rebased onto master and fixed breakage due to changes in `BaseSamplingPipeline`

* Added a prompt asking the user for the guidance weight

* Added a progress bar to the batched fork demo

* Improved fork example (using tree display)

* Added proper disposal of resources in batched examples

* Added some more comments in BatchedExecutorGuidance
Martin Evans 2024-02-26 15:41:57 +00:00 committed by GitHub
parent 364259aabe
commit 7d84625a67
6 changed files with 246 additions and 30 deletions
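
For readers unfamiliar with the technique: classifier-free guidance steers a sequence by pushing its logits away from what a second, negative-prompt ("guidance") context predicts. A minimal sketch of the combination, following the paper linked in the new doc comments and llama.cpp's `llama_sample_apply_guidance`; this is illustrative only, not the library's code:

```csharp
// Sketch of the classifier-free guidance combination (arXiv:2306.17806), as
// performed by llama.cpp's llama_sample_apply_guidance. Illustrative only.
// Inputs are assumed to be log-softmax-normalised logits over the vocabulary;
// scale == 1 leaves the guided logits unchanged, larger values push harder.
static void ApplyGuidanceSketch(float[] logits, float[] guidanceLogits, float scale)
{
    for (var i = 0; i < logits.Length; i++)
        logits[i] = guidanceLogits[i] + scale * (logits[i] - guidanceLogits[i]);
}
```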


@@ -26,6 +26,7 @@ public class ExampleRunner
         { "Semantic Kernel: Store", SemanticKernelMemory.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
+        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
         { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
     };


@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples;
 public class BatchedExecutorFork
 {
     private const int n_split = 16;
-    private const int n_len = 64;
+    private const int n_len = 72;

     public static async Task Run()
     {
@@ -24,41 +24,51 @@ public class BatchedExecutorFork
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);

         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");

         // Evaluate the initial prompt to create one conversation
-        var start = executor.Prompt(prompt);
+        using var start = executor.Prompt(prompt);
         await executor.Infer();

         // Create the root node of the tree
         var root = new Node(start);

-        // Run inference loop
-        for (var i = 0; i < n_len; i++)
-        {
-            if (i != 0)
-                await executor.Infer();
+        await AnsiConsole
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);

-            // Occasionally fork all the active conversations
-            if (i != 0 && i % n_split == 0)
-                root.Split();
+                // Run inference loop
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();

-            // Sample all active conversations
-            root.Sample();
-        }
+                    // Occasionally fork all the active conversations
+                    if (i != 0 && i % n_split == 0)
+                        root.Split();

-        Console.WriteLine($"{prompt}...");
-        root.Print(1);
+                    // Sample all active conversations
+                    root.Sample();

-        Console.WriteLine("Press any key to exit demo");
-        Console.ReadKey(true);
+                    // Update progress bar
+                    reporter.Increment(1);
+                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
+                }
+
+                // Display results
+                var display = new Tree(prompt);
+                root.Display(display);
+                AnsiConsole.Write(display);
+            });
     }

-    class Node
+    private class Node
     {
         private readonly StreamingTokenDecoder _decoder;
@@ -116,19 +126,18 @@ public class BatchedExecutorFork
             }
         }

-        public void Print(int indendation)
+        public void Display<T>(T tree, int depth = 0)
+            where T : IHasTreeNodes
         {
-            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
-            Console.ForegroundColor = colors[indendation % colors.Length];
+            var colors = new[] { "red", "green", "blue", "yellow", "white" };
+            var color = colors[depth % colors.Length];

             var message = _decoder.Read().ReplaceLineEndings("");
-            var prefix = new string(' ', indendation * 3);
-            var suffix = _conversation == null ? "..." : "";
-            Console.WriteLine($"{prefix}...{message}{suffix}");
+            var n = tree.AddNode($"[{color}]{message}[/]");

-            _left?.Print(indendation + 2);
-            _right?.Print(indendation + 2);
+            _left?.Display(n, depth + 1);
+            _right?.Display(n, depth + 1);
         }
     }
 }
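
A note on the reworked `Display` method: Spectre.Console's `Tree` and the `TreeNode` returned by `AddNode` both implement `IHasTreeNodes`, which is what lets the same generic method recurse from the root down through child nodes. A tiny illustration, assuming the Spectre.Console API exactly as used in the diff above:

```csharp
using Spectre.Console;

var tree = new Tree("root");                  // Tree implements IHasTreeNodes
var child = tree.AddNode("[green]child[/]");  // AddNode returns a TreeNode, also IHasTreeNodes
child.AddNode("[blue]grandchild[/]");         // ...so Display<T> can recurse at any depth
AnsiConsole.Write(tree);
```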


@@ -0,0 +1,128 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates using a batch to generate two sequences and then using one
+/// sequence as the negative guidance ("classifier free guidance") for the other.
+/// </summary>
+public class BatchedExecutorGuidance
+{
+    private const int n_len = 32;
+
+    public static async Task Run()
+    {
+        string modelPath = UserSettings.GetModelPath();
+
+        var parameters = new ModelParams(modelPath);
+        using var model = LLamaWeights.LoadFromFile(parameters);
+
+        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
+        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
+        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);
+
+        // Create an executor that can evaluate a batch of conversations together
+        using var executor = new BatchedExecutor(model, parameters);
+
+        // Print some info
+        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        Console.WriteLine($"Created executor with model: {name}");
+
+        // Load the two prompts into two conversations
+        using var guided = executor.Prompt(positivePrompt);
+        using var guidance = executor.Prompt(negativePrompt);
+
+        // Run inference to evaluate prompts
+        await AnsiConsole
+            .Status()
+            .Spinner(Spinner.Known.Line)
+            .StartAsync("Evaluating Prompts...", _ => executor.Infer());
+
+        // Fork the "guided" conversation. We'll run this one without guidance for comparison
+        using var unguided = guided.Fork();
+
+        // Run inference loop
+        var unguidedSampler = new GuidedSampler(null, weight);
+        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
+        var guidedSampler = new GuidedSampler(guidance, weight);
+        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
+        await AnsiConsole
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
+
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();
+
+                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
+                    // guidance. This serves as a comparison to show the effect of guidance.
+                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
+                    unguidedDecoder.Add(u);
+                    unguided.Prompt(u);
+
+                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
+                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
+                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
+                    guidedDecoder.Add(g);
+
+                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt).
+                    guided.Prompt(g);
+                    guidance.Prompt(g);
+
+                    // Early exit if we reach the natural end of the guided sentence
+                    if (g == model.EndOfSentenceToken)
+                        break;
+
+                    // Update progress bar
+                    reporter.Increment(1);
+                }
+            });
+
+        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
+        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
+    }
+
+    private class GuidedSampler(Conversation? guidance, float weight)
+        : BaseSamplingPipeline
+    {
+        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
+        {
+        }
+
+        public override ISamplingPipeline Clone()
+        {
+            throw new NotSupportedException();
+        }
+
+        protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            if (guidance == null)
+                return logits;
+
+            var logitsCopy = logits.ToArray();
+
+            // Get the logits generated by the guidance sequence
+            var guidanceLogits = guidance.Sample();
+
+            // Use those logits to guide this sequence
+            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);
+
+            return logitsCopy;
+        }
+
+        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            candidates.Temperature(ctx, 0.8f);
+            candidates.TopK(ctx, 25);
+
+            return candidates.SampleToken(ctx);
+        }
+    }
+}
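
The key invariant in the loop above: apart from their differing prompts, the guided and guidance conversations must be fed exactly the same tokens, otherwise the guidance logits drift out of alignment with the guided sequence. Condensed from the demo (same names as above):

```csharp
// Sample from the guided conversation, then feed the chosen token to BOTH
// conversations so the negative-prompt context stays token-for-token in sync.
var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
guided.Prompt(g);
guidance.Prompt(g);
```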


@@ -25,14 +25,14 @@ public class BatchedExecutorRewind
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);

         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");

         // Evaluate the initial prompt to create one conversation
-        var conversation = executor.Prompt(prompt);
+        using var conversation = executor.Prompt(prompt);

         // Create the start node wrapping the conversation
         var node = new Node(executor.Context);


@@ -185,6 +185,56 @@ namespace LLama.Native
             }
         }

+        /// <summary>
+        /// Apply classifier-free guidance to the logits, as described in the paper "Stay on topic with Classifier-Free Guidance" (https://arxiv.org/abs/2306.17806)
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="guidanceLogits">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="guidance">Guidance strength. 0 means no guidance; higher values apply stronger guidance.</param>
+        public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
+        {
+            if (guidanceLogits.Length != data.Length)
+                throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));
+
+            if (guidance < 0)
+                throw new ArgumentOutOfRangeException(nameof(guidance), "Guidance strength must be greater than or equal to zero");
+
+            // This method accepts 0 as "no guidance" with higher values meaning more. llama.cpp expects 1 as "no guidance"
+            // with higher values meaning more. Add one to shift up to the llama.cpp baseline.
+            guidance += 1;
+
+            // We need a raw logits array, which we don't have at this point.
+            // Copy the logits into a temporary array, apply guidance, then copy them back.
+            var logits = ArrayPool<float>.Shared.Rent(context.VocabCount);
+            try
+            {
+                // Copy logits into the temporary array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    logits[(int)item.id] = item.logit;
+                }
+
+                // Apply guidance
+                NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);
+
+                // Copy the logits back into the data array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    item.logit = logits[(int)item.id];
+                }
+
+                // No longer sorted, since we just mutated the logits!
+                sorted = false;
+            }
+            finally
+            {
+                ArrayPool<float>.Shared.Return(logits);
+            }
+        }
+
         /// <summary>
         /// Sample with temperature.
         /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
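
A hypothetical usage sketch for the new `Guidance` method (variable names and the `Create` helper are assumptions, not taken from this diff). Note the scale convention: this method treats 0 as "no guidance" and shifts by one internally, whereas the raw llama.cpp call treats 1 as "no guidance":

```csharp
// Hypothetical sketch: apply guidance during a custom sampling step.
// `context` is a SafeLLamaContextHandle; `guidedLogits` and `guidanceLogits`
// are assumed to hold the logits from the guided and negative-prompt contexts.
var candidates = LLamaTokenDataArray.Create(guidedLogits); // Create is assumed to build the array from raw logits
candidates.Guidance(context, guidanceLogits, 2.0f);        // 0.0f would leave the logits unchanged
var token = candidates.SampleToken(context);
```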


@@ -1,4 +1,5 @@
-using System.Runtime.InteropServices;
+using System;
+using System.Runtime.InteropServices;

 namespace LLama.Native
 {
@@ -23,6 +24,33 @@ namespace LLama.Native
             float penalty_freq,
             float penalty_present);

+        /// <summary>
+        /// Apply classifier-free guidance to the logits, as described in the paper "Stay on topic with Classifier-Free Guidance" (https://arxiv.org/abs/2306.17806)
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="logits">Logits extracted from the original generation context.</param>
+        /// <param name="logits_guidance">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
+        public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
+        {
+            if (logits == null)
+                throw new ArgumentNullException(nameof(logits));
+            if (logits_guidance == null)
+                throw new ArgumentNullException(nameof(logits_guidance));
+            if (logits.Length != ctx.VocabCount)
+                throw new ArgumentException("Logits count must equal context vocab size", nameof(logits));
+            if (logits_guidance.Length != ctx.VocabCount)
+                throw new ArgumentException("Guidance logits count must equal context vocab size", nameof(logits_guidance));
+
+            unsafe
+            {
+                fixed (float* logitsPtr = logits)
+                fixed (float* logitsGuidancePtr = logits_guidance)
+                    llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
+            }
+        }
+
         /// <summary>
         /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
         /// </summary>
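
Finally, a hedged sketch of calling the span-based wrapper directly (all variable names assumed). Both spans must be exactly vocab-sized, or the wrapper throws before touching native code:

```csharp
// Hypothetical direct use of the safer span-based wrapper.
float[] logits = new float[ctx.VocabCount];          // filled from the guided context
float[] guidanceLogits = new float[ctx.VocabCount];  // filled from the guidance (negative prompt) context
NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, scale: 2.0f);
```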