Merge branch 'master' into grammar_basics

Martin Evans 2023-08-22 14:06:57 +01:00 committed by GitHub
commit 759ae26f36
24 changed files with 397 additions and 209 deletions

View File

@ -1,9 +1,5 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,15 +8,22 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8));
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8));
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. The role names won't be printed.");
Console.ForegroundColor = ConsoleColor.White;
// show the prompt
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
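Taken together, the changes above replace the single LLamaContext constructor call with the new weights/context split. A minimal sketch of the resulting pattern, assuming the LLamaSharp APIs used in this diff (LLamaWeights.LoadFromFile, CreateContext, InteractiveExecutor, ChatSession) and a placeholder model path:

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// Placeholder model path, for illustration only.
var parameters = new ModelParams("path/to/model.bin", contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);   // load the weights once
using var context = model.CreateContext(parameters);       // create a context over those weights
var executor = new InteractiveExecutor(context);
var session = new ChatSession(executor);

var inferenceParams = new InferenceParams { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } };
foreach (var text in session.Chat("Hello, Bob.", inferenceParams))
    Console.Write(text);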

View File

@ -1,9 +1,5 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,10 +8,15 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream.
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
var session = new ChatSession(executor);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result.");

View File

@ -1,9 +1,4 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,7 +7,7 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var embedder = new LLamaEmbedder(new ModelParams(modelPath));
while (true)

View File

@ -1,9 +1,5 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/dan.txt").Trim();
InstructExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024)));
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InstructExecutor(context);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " +
@ -26,7 +25,7 @@ namespace LLama.Examples.NewVersion
while (true)
{
foreach (var text in ex.Infer(prompt, inferenceParams))
foreach (var text in executor.Infer(prompt, inferenceParams))
{
Console.Write(text);
}

View File

@ -1,21 +1,20 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
public class InteractiveModeExecute
{
public async static Task Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
var modelPath = Console.ReadLine();
var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();
InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256)));
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)");

View File

@ -1,10 +1,5 @@
using LLama.Common;
using LLama.OldVersion;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -13,10 +8,15 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream.
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
var session = new ChatSession(ex);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. Input \"save\" to save and reload the session.");

View File

@ -1,9 +1,5 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256)));
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)");
@ -47,9 +46,9 @@ namespace LLama.Examples.NewVersion
Console.WriteLine("All states saved!");
Console.ForegroundColor = ConsoleColor.White;
var model = ex.Context;
model.LoadState(modelStatePath);
ex = new InteractiveExecutor(model);
var ctx = ex.Context;
ctx.LoadState(modelStatePath);
ex = new InteractiveExecutor(ctx);
ex.LoadState(executorStatePath);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Loaded state!");

View File

@ -1,11 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
public class QuantizeModel
{
@ -13,13 +6,16 @@ namespace LLama.Examples.NewVersion
{
Console.Write("Please input your original model path: ");
var inputPath = Console.ReadLine();
Console.Write("Please input your output model path: ");
var outputPath = Console.ReadLine();
Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): ");
var quantizeType = Console.ReadLine();
if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType))
{
Console.WriteLine("Quantization succeed!");
Console.WriteLine("Quantization succeeded!");
}
else
{

View File

@ -1,9 +1,4 @@
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
@ -12,9 +7,11 @@ namespace LLama.Examples.NewVersion
public static void Run()
{
Console.Write("Please input your model path: ");
string modelPath = Console.ReadLine();
var modelPath = Console.ReadLine();
StatelessExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256)));
var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5);
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " +
@ -29,10 +26,10 @@ namespace LLama.Examples.NewVersion
{
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
string prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
var prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt.Trim()} Answer: ";
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
{
Console.Write(text);

View File

@ -20,9 +20,9 @@ namespace LLama.Examples.NewVersion
using var weights = LLamaWeights.LoadFromFile(@params);
// Create 2 contexts sharing the same weights
using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8);
using var aliceCtx = weights.CreateContext(@params);
var alice = new InteractiveExecutor(aliceCtx);
using var bobCtx = weights.CreateContext(@params, Encoding.UTF8);
using var bobCtx = weights.CreateContext(@params);
var bob = new InteractiveExecutor(bobCtx);
// Initial alice prompt

View File

@ -1,10 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.NewVersion
{
public class NewVersionTestRunner
{

View File

@ -16,7 +16,7 @@ namespace LLama.Unittest
ContextSize = 768,
};
_weights = LLamaWeights.LoadFromFile(@params);
_context = _weights.CreateContext(@params, Encoding.UTF8);
_context = _weights.CreateContext(@params);
}
public void Dispose()

View File

@ -0,0 +1,70 @@
using LLama.Common;
using Xunit.Abstractions;
namespace LLama.Unittest
{
public class StatelessExecutorTest
: IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaWeights _weights;
private readonly ModelParams _params;
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
{
ContextSize = 60,
Seed = 1754
};
_weights = LLamaWeights.LoadFromFile(_params);
}
public void Dispose()
{
_weights.Dispose();
}
[Fact]
public void Stateless()
{
var executor = new StatelessExecutor(_weights, _params);
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
_testOutputHelper.WriteLine(result1);
// Check that it produced the exact same result both times
Assert.Equal(result1, result2);
}
[Fact]
public void OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);
const string question = " Question. why is a cat the best pet?\nAnswer: ";
// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context
var @params = new InferenceParams()
{
MaxTokens = 100,
TokensKeep = question.Length,
};
var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
_testOutputHelper.WriteLine(result1);
// Check that it produced the exact same result both times
Assert.Equal(result1, result2);
}
}
}

View File

@ -2,7 +2,8 @@
namespace LLama.Web.Common
{
public class ModelOptions : IModelParams
public class ModelOptions
: IModelParams
{
public string Name { get; set; }
@ -111,5 +112,9 @@ namespace LLama.Web.Common
/// </summary>
public bool MulMatQ { get; set; }
}
/// <summary>
/// The encoding to use for models
/// </summary>
public string Encoding { get; set; } = "UTF-8";
}
}

View File

@ -1,6 +1,4 @@
using System;
namespace LLama.Abstractions
{
public interface IModelParams
{
@ -119,5 +117,10 @@ namespace LLama.Abstractions
/// Use experimental mul_mat_q kernels
/// </summary>
bool MulMatQ { get; set; }
/// <summary>
/// The encoding to use for models
/// </summary>
string Encoding { get; set; }
}
}

View File

@ -1,5 +1,6 @@
using LLama.Abstractions;
using System;
using System.Text;
namespace LLama.Common
{
@ -111,34 +112,41 @@ namespace LLama.Common
/// </summary>
public bool MulMatQ { get; set; }
/// <summary>
///
/// </summary>
/// <param name="modelPath">The model path.</param>
/// <param name="contextSize">Model context size (n_ctx)</param>
/// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param>
/// <param name="seed">Seed for the random number generator (seed)</param>
/// <param name="useFp16Memory">Whether to use f16 instead of f32 for memory kv (memory_f16)</param>
/// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param>
/// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param>
/// <param name="perplexity">Thether to compute perplexity over the prompt (perplexity)</param>
/// <param name="loraAdapter">Lora adapter path (lora_adapter)</param>
/// <param name="loraBase">Base model path for the lora adapter (lora_base)</param>
/// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param>
/// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param>
/// <param name="convertEosToNewLine">Whether to convert eos to newline during the inference.</param>
/// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.</param>
/// <param name="gqa">Grouped-Query Attention</param>
/// <param name="rmsNormEps">RMS Norm Epsilon</param>
/// <param name="rope_freq_base">RoPE base frequency.</param>
/// <param name="rope_freq_scale">RoPE frequency scaling factor</param>
/// <param name="muMatQ">Use experimental mul_mat_q kernels</param>
public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
int seed = 1337, bool useFp16Memory = true,
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
bool convertEosToNewLine = false, bool embeddingMode = false,
int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool muMatQ = false)
/// <summary>
/// The encoding to use to convert text for the model
/// </summary>
public string Encoding { get; set; } = "UTF-8";
/// <summary>
///
/// </summary>
/// <param name="modelPath">The model path.</param>
/// <param name="contextSize">Model context size (n_ctx)</param>
/// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param>
/// <param name="seed">Seed for the random number generator (seed)</param>
/// <param name="useFp16Memory">Whether to use f16 instead of f32 for memory kv (memory_f16)</param>
/// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param>
/// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param>
/// <param name="perplexity">Thether to compute perplexity over the prompt (perplexity)</param>
/// <param name="loraAdapter">Lora adapter path (lora_adapter)</param>
/// <param name="loraBase">Base model path for the lora adapter (lora_base)</param>
/// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param>
/// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param>
/// <param name="convertEosToNewLine">Whether to convert eos to newline during the inference.</param>
/// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.</param>
/// <param name="groupedQueryAttention">Grouped-Query Attention</param>
/// <param name="rmsNormEps">RMS Norm Epsilon</param>
/// <param name="ropeFreqBase">RoPE base frequency.</param>
/// <param name="ropeFreqScale">RoPE frequency scaling factor</param>
/// <param name="muMatQ">Use experimental mul_mat_q kernels</param>
/// <param name="encoding">The encoding to use to convert text for the model</param>
public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
int seed = 1337, bool useFp16Memory = true,
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
bool convertEosToNewLine = false, bool embeddingMode = false,
int groupedQueryAttention = 1, float rmsNormEps = 5e-6f, float ropeFreqBase = 10000.0f, float ropeFreqScale = 1f, bool muMatQ = false,
string encoding = "UTF-8")
{
ContextSize = contextSize;
GpuLayerCount = gpuLayerCount;
@ -154,11 +162,12 @@ namespace LLama.Common
BatchSize = batchSize;
ConvertEosToNewLine = convertEosToNewLine;
EmbeddingMode = embeddingMode;
GroupedQueryAttention = gqa;
GroupedQueryAttention = groupedQueryAttention;
RmsNormEpsilon = rmsNormEps;
RopeFrequencyBase = rope_freq_base;
RopeFrequencyScale = rope_freq_scale;
RopeFrequencyBase = ropeFreqBase;
RopeFrequencyScale = ropeFreqScale;
MulMatQ = muMatQ;
}
Encoding = encoding;
}
}
}
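A short, illustrative construction of ModelParams using the renamed parameters and the new encoding argument added in this change; the model path and values are placeholders, not recommendations:

var @params = new ModelParams(
    "path/to/model.bin",        // placeholder path
    contextSize: 1024,
    gpuLayerCount: 5,
    seed: 1337,
    ropeFreqBase: 10000.0f,
    ropeFreqScale: 1f,
    encoding: "UTF-8");         // new: text encoding used for the model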

View File

@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
namespace LLama.Extensions
{
internal static class ListExtensions
{
public static void AddRangeSpan<T>(this List<T> list, ReadOnlySpan<T> span)
{
for (var i = 0; i < span.Length; i++)
list.Add(span[i]);
}
}
}

View File

@ -1,6 +1,7 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@ -9,7 +10,6 @@ using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using Microsoft.Win32.SafeHandles;
using LLama.Abstractions;
namespace LLama
@ -61,26 +61,25 @@ namespace LLama
///
/// </summary>
/// <param name="params">Model params.</param>
/// <param name="encoding">Encoding to deal with text input.</param>
/// <param name="logger">The logger.</param>
[Obsolete("Use the LLamaWeights.CreateContext instead")]
public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogger? logger = null)
public LLamaContext(IModelParams @params, ILLamaLogger? logger = null)
{
Params = @params;
_logger = logger;
_encoding = Encoding.GetEncoding(encoding);
_encoding = Encoding.GetEncoding(@params.Encoding);
_logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
_ctx = Utils.InitLLamaContextFromModelParams(Params);
}
internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null)
internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILLamaLogger? logger = null)
{
Params = @params;
_logger = logger;
_encoding = encoding;
_encoding = Encoding.GetEncoding(@params.Encoding);
_ctx = nativeContext;
}
@ -89,10 +88,9 @@ namespace LLama
/// </summary>
/// <param name="model"></param>
/// <param name="params"></param>
/// <param name="encoding"></param>
/// <param name="logger"></param>
/// <exception cref="ObjectDisposedException"></exception>
public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null)
public LLamaContext(LLamaWeights model, IModelParams @params, ILLamaLogger? logger = null)
{
if (model.NativeHandle.IsClosed)
throw new ObjectDisposedException("Cannot create context, model weights have been disposed");
@ -100,7 +98,7 @@ namespace LLama
Params = @params;
_logger = logger;
_encoding = encoding;
_encoding = Encoding.GetEncoding(@params.Encoding);
using var pin = @params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
@ -115,7 +113,7 @@ namespace LLama
using var pin = Params.ToLlamaContextParams(out var lparams);
// Create a blank new context for the model
var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params, _encoding);
var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params);
// Copy across the state
using var state = GetState();
@ -398,6 +396,7 @@ namespace LLama
return candidates_p;
}
#region eval overloads
/// <summary>
///
/// </summary>
@ -405,18 +404,72 @@ namespace LLama
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
public int Eval(llama_token[] tokens, llama_token pastTokensCount)
{
int total = tokens.Length;
for(int i = 0; i < total; i += Params.BatchSize)
return Eval(tokens.AsSpan(), pastTokensCount);
}
/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
return Eval(span, pastTokensCount);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.
var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
try
{
int n_eval = total - i;
if(n_eval > Params.BatchSize)
tokens.CopyTo(rented, 0);
return Eval(rented, pastTokensCount);
}
finally
{
ArrayPool<llama_token>.Shared.Return(rented);
}
#endif
}
/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}
/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += Params.BatchSize)
{
var n_eval = total - i;
if (n_eval > Params.BatchSize)
{
n_eval = Params.BatchSize;
}
if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
{
_logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
@ -426,6 +479,7 @@ namespace LLama
}
return pastTokensCount;
}
#endregion
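// Usage sketch (illustrative only, not part of this class): the overloads above
// all funnel into the ReadOnlySpan overload, which evaluates the tokens in
// chunks of Params.BatchSize. For example, after tokenizing a prompt:
//
//     var tokens = context.Tokenize(prompt).ToList();
//     pastTokensCount = context.Eval(tokens, pastTokensCount);
//
// On .NET 5+ the List overload reads the backing array directly via
// CollectionsMarshal; on netstandard2.0 it copies into a pooled array first.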
internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
{
@ -433,6 +487,16 @@ namespace LLama
yield return _ctx.TokenToString(id, _encoding);
}
/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}
/// <inheritdoc />
public virtual void Dispose()
{

View File

@ -189,7 +189,7 @@ namespace LLama
}
TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount);
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);
if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{

View File

@ -178,7 +178,7 @@ namespace LLama
}
TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount);
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);
if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{

View File

@ -1,125 +1,159 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// This executor infers the input as a one-time job. Previous inputs won't impact the
/// response to the current input.
/// </summary>
public class StatelessExecutor : ILLamaExecutor
public class StatelessExecutor
: ILLamaExecutor
{
private LLamaContext _context;
private LLamaContext.State _originalState;
private readonly LLamaWeights _weights;
private readonly IModelParams _params;
/// <summary>
/// The context used by the executor when running the inference.
/// </summary>
public LLamaContext Context => _context;
public LLamaContext Context { get; private set; }
/// <summary>
///
/// Create a new stateless executor which will use the given model
/// </summary>
/// <param name="context">The LLama model.</param>
/// <param name="weights"></param>
/// <param name="params"></param>
public StatelessExecutor(LLamaWeights weights, IModelParams @params)
{
_weights = weights;
_params = @params;
Context = _weights.CreateContext(_params);
Context.Dispose();
}
/// <summary>
/// Create a new stateless executor which will use the model used to create the given context
/// </summary>
/// <param name="context"></param>
[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
public StatelessExecutor(LLamaContext context)
{
_context = context;
var tokens = context.Tokenize(" ", true).ToArray();
_context.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _context.Params.Threads);
_originalState = context.GetState();
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding));
_params = context.Params;
Context = _weights.CreateContext(_params);
Context.Dispose();
}
/// <inheritdoc />
public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
int n_past = 1;
if(inferenceParams is null)
{
inferenceParams = new InferenceParams();
}
List<llama_token> lastTokens = new(inferenceParams.RepeatLastTokensCount);
for(int i = 0; i < lastTokens.Count; i++)
{
lastTokens[i] = 0;
}
List<llama_token> tokens = _context.Tokenize(text, true).ToList();
int n_prompt_tokens = tokens.Count;
using var context = _weights.CreateContext(_params);
Context = context;
_context.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _context.Params.Threads);
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
Context = _weights.CreateContext(Context.Params);
if (inferenceParams != null)
{
if (inferenceParams.TokensKeep > Context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
}
cancellationToken.ThrowIfCancellationRequested();
var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
var n_past = 1;
inferenceParams ??= new InferenceParams();
var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
lastTokens.Add(0);
var tokens = Context.Tokenize(text).ToList();
var n_prompt_tokens = tokens.Count;
Context.Eval(tokens, n_past);
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
var mu = (float?)null;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens; i++)
{
if (cancellationToken.IsCancellationRequested)
{
_context.LoadState(_originalState);
break;
}
var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _context.ContextSize : inferenceParams.RepeatLastTokensCount;
var tokenDataArray = _context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
lastTokens.Add(id);
string response = _context.NativeHandle.TokenToString(id, _context.Encoding);
var response = Context.TokenToString(id);
yield return response;
tokens.Clear();
tokens.Add(id);
if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0)
{
string last_output = "";
foreach (var token in lastTokens)
{
last_output += _context.NativeHandle.TokenToString(token, _context.Encoding);
}
bool should_break = false;
foreach (var antiprompt in inferenceParams.AntiPrompts)
{
if (last_output.EndsWith(antiprompt))
{
should_break = true;
break;
}
}
if (should_break)
{
break;
}
}
if (EndsWithAntiprompt(lastTokens, antiprompts))
break;
// when run out of context
if (n_past + tokens.Count > _context.ContextSize)
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
if (n_past + tokens.Count > Context.ContextSize)
{
int n_left = n_past - inferenceParams.TokensKeep;
var n_left = n_past - inferenceParams.TokensKeep;
n_past = Math.Max(1, inferenceParams.TokensKeep);
// insert n_left/2 tokens at the start of embed from last_n_tokens
tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_context.ContextSize - n_left / 2 - tokens.Count));
tokens.Clear();
tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
}
n_past = _context.Eval(tokens.ToArray(), n_past);
n_past = Context.Eval(tokens, n_past);
}
}
/// <summary>
/// Check if the given tokens list ends with any of the antiprompts
/// </summary>
/// <param name="tokens"></param>
/// <param name="antiprompts"></param>
/// <returns></returns>
private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
{
if (antiprompts.Count == 0 || tokens.Count == 0)
return false;
var builder = new StringBuilder();
foreach (var token in tokens)
builder.Append(Context.TokenToString(token));
var last_output = builder.ToString();
foreach (var antiprompt in antiprompts)
{
if (last_output.EndsWith(antiprompt))
return true;
}
_context.LoadState(_originalState);
return false;
}
/// <inheritdoc />

View File

@ -20,9 +20,15 @@ namespace LLama
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle => _weights;
private LLamaWeights(SafeLlamaModelHandle weights)
/// <summary>
/// Encoding to use to convert text into bytes for the model
/// </summary>
public Encoding Encoding { get; }
internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
{
_weights = weights;
Encoding = encoding;
}
/// <summary>
@ -38,7 +44,7 @@ namespace LLama
if (!string.IsNullOrEmpty(@params.LoraAdapter))
weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
return new LLamaWeights(weights);
return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding));
}
/// <inheritdoc />
@ -51,11 +57,10 @@ namespace LLama
/// Create a llama_context using this model
/// </summary>
/// <param name="params"></param>
/// <param name="encoding"></param>
/// <returns></returns>
public LLamaContext CreateContext(IModelParams @params, Encoding encoding)
public LLamaContext CreateContext(IModelParams @params)
{
return new LLamaContext(this, @params, encoding);
return new LLamaContext(this, @params);
}
}
}

View File

@ -138,7 +138,6 @@ namespace LLama.Native
/// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
public Span<float> GetLogits()
{
@ -179,12 +178,14 @@ namespace LLama.Native
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <param name="n_threads"></param>
/// <returns>Returns true on success</returns>
public bool Eval(Memory<int> tokens, int n_past, int n_threads)
public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
{
using var pin = tokens.Pin();
unsafe
{
return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
fixed (int* pinned = tokens)
{
return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
}
}
}
}
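The switch from Memory&lt;int&gt;.Pin() to a fixed statement relies on a general C# feature (available since C# 7.3): a span can be pinned directly in a fixed block, which is convenient when handing a raw pointer to native code. A standalone sketch, unrelated to the LLamaSharp API:

static unsafe long Sum(ReadOnlySpan<int> values)
{
    // `fixed` pins the span's underlying memory for the duration of the block,
    // so the raw pointer can safely be passed to native code.
    fixed (int* ptr = values)
    {
        long total = 0;
        for (var i = 0; i < values.Length; i++)
            total += ptr[i];
        return total;
    }
}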

View File

@ -37,7 +37,7 @@ namespace LLama
[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
{
var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}