Classifier Free Guidance (#536)
* Added a `Guidance` method to `LLamaTokenDataArray` which applies classifier free guidance
* Factored out a safer `llama_sample_apply_guidance` method based on spans
* Created a guided sampling demo using the batched executor
* Fixed comment, "classifier free" not "context free"
* Rebased onto master and fixed breakage due to changes in `BaseSamplingPipeline`
* Asking user for guidance weight
* Progress bar in batched fork demo
* Improved fork example (using tree display)
* Added proper disposal of resources in batched examples
* Added some more comments in BatchedExecutorGuidance
parent 364259aabe
commit 7d84625a67
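Taken together, the diff adds classifier-free guidance at two levels: a high-level `Guidance` method on `LLamaTokenDataArray` (where a weight of 0 means no guidance) and a span-based `NativeApi.llama_sample_apply_guidance` wrapper that follows llama.cpp's convention (where 1 means no guidance). A minimal sketch of how the new high-level method might compose into a custom sampling step; the method name and surrounding scaffolding below are illustrative, not part of the commit:

```csharp
// Illustrative sketch only. Assumes `ctx` is the guided context's handle,
// `candidates` holds its token data, and `guidanceLogits` are raw logits
// from a second, negative-prompt context (same model, same vocab size).
static LLamaToken SampleWithGuidance(
    SafeLLamaContextHandle ctx,
    LLamaTokenDataArray candidates,
    ReadOnlySpan<float> guidanceLogits,
    float weight = 2.0f)  // 0 = no guidance; higher = stronger steering
{
    candidates.Guidance(ctx, guidanceLogits, weight);  // new in this commit
    candidates.Temperature(ctx, 0.8f);
    candidates.TopK(ctx, 25);
    return candidates.SampleToken(ctx);
}
```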
@@ -26,6 +26,7 @@ public class ExampleRunner
         { "Semantic Kernel: Store", SemanticKernelMemory.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
+        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
         { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
     };
 
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples;
 public class BatchedExecutorFork
 {
     private const int n_split = 16;
-    private const int n_len = 64;
+    private const int n_len = 72;
 
     public static async Task Run()
     {
@@ -24,41 +24,51 @@ public class BatchedExecutorFork
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
 
         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Evaluate the initial prompt to create one conversation
-        var start = executor.Prompt(prompt);
+        using var start = executor.Prompt(prompt);
         await executor.Infer();
 
         // Create the root node of the tree
         var root = new Node(start);
 
-        // Run inference loop
-        for (var i = 0; i < n_len; i++)
-        {
-            if (i != 0)
-                await executor.Infer();
-
-            // Occasionally fork all the active conversations
-            if (i != 0 && i % n_split == 0)
-                root.Split();
-
-            // Sample all active conversations
-            root.Sample();
-        }
-
-        Console.WriteLine($"{prompt}...");
-        root.Print(1);
-
-        Console.WriteLine("Press any key to exit demo");
-        Console.ReadKey(true);
+        await AnsiConsole
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
+
+                // Run inference loop
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();
+
+                    // Occasionally fork all the active conversations
+                    if (i != 0 && i % n_split == 0)
+                        root.Split();
+
+                    // Sample all active conversations
+                    root.Sample();
+
+                    // Update progress bar
+                    reporter.Increment(1);
+                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
+                }
+
+                // Display results
+                var display = new Tree(prompt);
+                root.Display(display);
+                AnsiConsole.Write(display);
+            });
     }
 
-    class Node
+    private class Node
     {
         private readonly StreamingTokenDecoder _decoder;
@@ -116,19 +126,18 @@ public class BatchedExecutorFork
             }
         }
 
-        public void Print(int indendation)
+        public void Display<T>(T tree, int depth = 0)
+            where T : IHasTreeNodes
         {
-            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
-            Console.ForegroundColor = colors[indendation % colors.Length];
+            var colors = new[] { "red", "green", "blue", "yellow", "white" };
+            var color = colors[depth % colors.Length];
 
             var message = _decoder.Read().ReplaceLineEndings("");
 
-            var prefix = new string(' ', indendation * 3);
-            var suffix = _conversation == null ? "..." : "";
-            Console.WriteLine($"{prefix}...{message}{suffix}");
+            var n = tree.AddNode($"[{color}]{message}[/]");
 
-            _left?.Print(indendation + 2);
-            _right?.Print(indendation + 2);
+            _left?.Display(n, depth + 1);
+            _right?.Display(n, depth + 1);
         }
     }
 }
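The reworked `Display<T>` above relies on Spectre.Console's tree rendering: both `Tree` and the `TreeNode` returned by `AddNode` implement `IHasTreeNodes`, so one generic method can attach children at any depth. A minimal sketch of that pattern, with placeholder strings:

```csharp
using Spectre.Console;

// Tree and TreeNode both implement IHasTreeNodes, so AddNode works the same
// way on the root and on any child, which is what Display<T> relies on.
var tree = new Tree("Not many people know that");
var branch = tree.AddNode("[red]...this text is[/]");
branch.AddNode("[green]a placeholder[/]");
AnsiConsole.Write(tree);
```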
@@ -0,0 +1,128 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates using a batch to generate two sequences and then using one
+/// sequence as the negative guidance ("classifier free guidance") for the other.
+/// </summary>
+public class BatchedExecutorGuidance
+{
+    private const int n_len = 32;
+
+    public static async Task Run()
+    {
+        string modelPath = UserSettings.GetModelPath();
+
+        var parameters = new ModelParams(modelPath);
+        using var model = LLamaWeights.LoadFromFile(parameters);
+
+        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
+        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
+        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);
+
+        // Create an executor that can evaluate a batch of conversations together
+        using var executor = new BatchedExecutor(model, parameters);
+
+        // Print some info
+        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        Console.WriteLine($"Created executor with model: {name}");
+
+        // Load the two prompts into two conversations
+        using var guided = executor.Prompt(positivePrompt);
+        using var guidance = executor.Prompt(negativePrompt);
+
+        // Run inference to evaluate prompts
+        await AnsiConsole
+            .Status()
+            .Spinner(Spinner.Known.Line)
+            .StartAsync("Evaluating Prompts...", _ => executor.Infer());
+
+        // Fork the "guided" conversation. We'll run this one without guidance for comparison
+        using var unguided = guided.Fork();
+
+        // Run inference loop
+        var unguidedSampler = new GuidedSampler(null, weight);
+        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
+        var guidedSampler = new GuidedSampler(guidance, weight);
+        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
+        await AnsiConsole
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
+
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();
+
+                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
+                    // guidance. This serves as a comparison to show the effect of guidance.
+                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
+                    unguidedDecoder.Add(u);
+                    unguided.Prompt(u);
+
+                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
+                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
+                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
+                    guidedDecoder.Add(g);
+
+                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt).
+                    guided.Prompt(g);
+                    guidance.Prompt(g);
+
+                    // Early exit if we reach the natural end of the guided sentence
+                    if (g == model.EndOfSentenceToken)
+                        break;
+
+                    // Update progress bar
+                    reporter.Increment(1);
+                }
+            });
+
+        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
+        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
+    }
+
+    private class GuidedSampler(Conversation? guidance, float weight)
+        : BaseSamplingPipeline
+    {
+        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
+        {
+        }
+
+        public override ISamplingPipeline Clone()
+        {
+            throw new NotSupportedException();
+        }
+
+        protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            if (guidance == null)
+                return logits;
+
+            var logitsCopy = logits.ToArray();
+
+            // Get the logits generated by the guidance sequence
+            var guidanceLogits = guidance.Sample();
+
+            // Use those logits to guide this sequence
+            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);
+
+            return logitsCopy;
+        }
+
+        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            candidates.Temperature(ctx, 0.8f);
+            candidates.TopK(ctx, 25);
+
+            return candidates.SampleToken(ctx);
+        }
+    }
+}
@@ -25,14 +25,14 @@ public class BatchedExecutorRewind
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
 
         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Evaluate the initial prompt to create one conversation
-        var conversation = executor.Prompt(prompt);
+        using var conversation = executor.Prompt(prompt);
 
         // Create the start node wrapping the conversation
         var node = new Node(executor.Context);
@@ -185,6 +185,56 @@ namespace LLama.Native
             }
         }
 
+        /// <summary>
+        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="guidanceLogits">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="guidance">Guidance strength. 0 means no guidance, higher values apply stronger guidance.</param>
+        public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
+        {
+            if (guidanceLogits.Length != data.Length)
+                throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));
+
+            if (guidance < 0)
+                throw new ArgumentOutOfRangeException(nameof(guidance), "Guidance strength must be greater than or equal to zero");
+
+            // This method accepts 0 (no guidance), higher means more. llama.cpp expects 1 (no guidance), higher means more.
+            // Add one to move up to the llama.cpp baseline.
+            guidance += 1;
+
+            // We need a raw logits array, which we don't have at this point.
+            // Copy the logits into a temporary array, apply guidance, then copy them back.
+            var logits = ArrayPool<float>.Shared.Rent(context.VocabCount);
+            try
+            {
+                // Copy logits into a temporary array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    logits[(int)item.id] = item.logit;
+                }
+
+                // Apply guidance
+                NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);
+
+                // Copy logits back into data array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    item.logit = logits[(int)item.id];
+                }
+
+                // No longer sorted since we just mutated logits!
+                sorted = false;
+            }
+            finally
+            {
+                ArrayPool<float>.Shared.Return(logits);
+            }
+        }
+
         /// <summary>
         /// Sample with temperature.
         /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
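For context on the conversion above: the llama.cpp function being wrapped applies the mixing rule from the paper. Both logit vectors are log-softmax normalised, then each guided logit is moved away from the guidance (negative prompt) logit as `l[i] = g[i] + scale * (l[i] - g[i])`, so `scale = 1` is a no-op, which is why `Guidance` adds one to its 0-based weight. A rough managed re-implementation, purely illustrative (neither helper below is library code):

```csharp
// Illustrative managed sketch of llama.cpp's guidance mix; not library code.
static void ApplyGuidanceSketch(Span<float> logits, Span<float> guidanceLogits, float scale)
{
    LogSoftmax(logits);
    LogSoftmax(guidanceLogits);

    // scale = 1 leaves logits unchanged; larger values push the guided
    // distribution further away from the negative-prompt distribution.
    for (var i = 0; i < logits.Length; i++)
        logits[i] = guidanceLogits[i] + scale * (logits[i] - guidanceLogits[i]);
}

static void LogSoftmax(Span<float> x)
{
    var max = float.NegativeInfinity;
    foreach (var v in x)
        max = Math.Max(max, v);

    var sum = 0.0;
    foreach (var v in x)
        sum += Math.Exp(v - max);

    var logSum = (float)Math.Log(sum);
    for (var i = 0; i < x.Length; i++)
        x[i] -= max + logSum;
}
```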
@@ -1,4 +1,5 @@
+using System;
 using System.Runtime.InteropServices;
 
 namespace LLama.Native
 {
@@ -23,6 +24,33 @@ namespace LLama.Native
             float penalty_freq,
             float penalty_present);
 
+        /// <summary>
+        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="logits">Logits extracted from the original generation context.</param>
+        /// <param name="logits_guidance">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
+        public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
+        {
+            if (logits == null)
+                throw new ArgumentNullException(nameof(logits));
+            if (logits_guidance == null)
+                throw new ArgumentNullException(nameof(logits_guidance));
+            if (logits.Length != ctx.VocabCount)
+                throw new ArgumentException("Logits count must equal context vocab size", nameof(logits));
+            if (logits_guidance.Length != ctx.VocabCount)
+                throw new ArgumentException("Guidance logits count must equal context vocab size", nameof(logits_guidance));
+
+            unsafe
+            {
+                fixed (float* logitsPtr = logits)
+                fixed (float* logitsGuidancePtr = logits_guidance)
+                    llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
+            }
+        }
+
         /// <summary>
         /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
         /// </summary>
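Note the two scale conventions that now coexist: this wrapper keeps llama.cpp's baseline (1.0f means no guidance), while `LLamaTokenDataArray.Guidance` exposes a 0-based weight and adds one internally before calling it. A hedged example of calling the wrapper directly; `ctx` and the two arrays are placeholders, and each array must be `ctx.VocabCount` long:

```csharp
// Placeholder setup: in real use these logits come from two evaluated
// contexts sharing the same model (and therefore the same vocab size).
var mainLogits = new float[ctx.VocabCount];
var negativeLogits = new float[ctx.VocabCount];

// 1.0f would be a no-op under llama.cpp's convention; 2.0f applies
// noticeable guidance away from the negative prompt.
NativeApi.llama_sample_apply_guidance(ctx, mainLogits, negativeLogits, 2.0f);
```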