Revert "Standardizing Image Data implementation"

This reverts commit b2423fe6e9.
This commit is contained in:
Zoli Somogyi 2024-04-24 07:57:17 +02:00
parent 6bd269da60
commit 156d7bb463
3 changed files with 20 additions and 65 deletions

View File

@ -1,8 +1,7 @@
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common; using LLama.Common;
using Spectre.Console; using Spectre.Console;
using LLama.Abstractions; using LLama.Native;
namespace LLama.Examples.Examples namespace LLama.Examples.Examples
{ {
@ -19,12 +18,8 @@ namespace LLama.Examples.Examples
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
var parameters = new ModelParams(modelPath) var parameters = new ModelParams(modelPath);
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 10
};
using var model = LLamaWeights.LoadFromFile(parameters); using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters); using var context = model.CreateContext(parameters);
@ -47,16 +42,16 @@ namespace LLama.Examples.Examples
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count(); var imageCount = imageMatches.Count();
var hasImages = imageCount > 0; var hasImages = imageCount > 0;
byte[][] imageBytes = null;
if (hasImages) if (hasImages)
{ {
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value); var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
List<byte[]> imageBytes;
try try
{ {
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray(); imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
} }
catch (IOException exception) catch (IOException exception)
{ {
@ -69,15 +64,17 @@ namespace LLama.Examples.Examples
break; break;
} }
// For every prompt that contains images, clear the cache.
// When the prompt contains images, clear the KV cache to restart the conversation.
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );
int index = 0; int index = 0;
foreach (var path in imagePathsWithCurlyBraces) foreach (var path in imagePathsWithCurlyBraces)
{ {
// Replace the first image path with the <image> tag and remove the remaining image paths // Replace the first image path with the <image> tag and remove the remaining image paths
if (index++ == 0) prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
} }
@ -102,7 +99,7 @@ namespace LLama.Examples.Examples
// //
foreach (var image in imagePaths) foreach (var image in imagePaths)
{ {
ex.Images.Add(new ImageData(ImageData.DataType.ImagePath, image)); ex.Images.Add(await File.ReadAllBytesAsync(image));
} }
} }
@ -118,7 +115,7 @@ namespace LLama.Examples.Examples
// let the user finish with exit // let the user finish with exit
// //
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break; break;
} }

View File

@ -25,9 +25,9 @@ namespace LLama.Abstractions
public LLavaWeights? ClipModel { get; } public LLavaWeights? ClipModel { get; }
/// <summary> /// <summary>
/// List of images: Image file path, uri or image byte array. See ImageData. /// List of images: List of images in byte array format.
/// </summary> /// </summary>
public List<ImageData> Images { get; } public List<byte[]> Images { get; }
/// <summary> /// <summary>
/// Asynchronously infers a response from the model. /// Asynchronously infers a response from the model.
@ -38,46 +38,4 @@ namespace LLama.Abstractions
/// <returns></returns> /// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default); IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
} }
/// <summary>
/// Pairs a piece of image data with a tag saying which form that data takes.
/// The constructor only stores both values; it does not check that
/// <see cref="Data"/> actually matches <see cref="Type"/>.
/// </summary>
public class ImageData
{
    /// <summary>
    /// The forms in which an image may be supplied.
    /// </summary>
    public enum DataType
    {
        /// <summary>
        /// The image is given as a file path.
        /// </summary>
        ImagePath,

        /// <summary>
        /// The image is given as a byte array.
        /// </summary>
        ImageBytes,

        /// <summary>
        /// The image is given as a uri.
        /// </summary>
        ImageURL
    }

    /// <summary>
    /// Which form of image data this instance carries.
    /// </summary>
    public DataType Type { get; set; }

    /// <summary>
    /// The image payload itself (string, byte array or uri).
    /// </summary>
    public object? Data { get; set; }

    /// <summary>
    /// Creates an instance from a type tag and its payload.
    /// </summary>
    /// <param name="type">Form of the supplied image data.</param>
    /// <param name="data">The image data; stored exactly as given.</param>
    public ImageData(DataType type, object data) => (Type, Data) = (type, data);
}
} }

View File

@ -34,7 +34,7 @@ namespace LLama
public LLavaWeights? ClipModel { get; } public LLavaWeights? ClipModel { get; }
/// <inheritdoc /> /// <inheritdoc />
public List<ImageData> Images { get; set; } public List<byte[]> Images { get; set; }
/// <summary> /// <summary>
/// The context used by the executor when running the inference. /// The context used by the executor when running the inference.
@ -49,7 +49,7 @@ namespace LLama
/// <param name="logger"></param> /// <param name="logger"></param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{ {
Images = new List<ImageData>(); Images = new List<byte[]>();
_weights = weights; _weights = weights;
_params = @params; _params = @params;
_logger = logger; _logger = logger;
@ -90,7 +90,7 @@ namespace LLama
lastTokens.Add(0); lastTokens.Add(0);
// Tokenize the prompt // Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList(); var tokens = Context.Tokenize(prompt, special: true).ToList();
lastTokens.AddRange(tokens); lastTokens.AddRange(tokens);
// Evaluate the prompt, in chunks smaller than the max batch size // Evaluate the prompt, in chunks smaller than the max batch size
@ -124,7 +124,7 @@ namespace LLama
} }
// Check if this is the EOS token // Check if this is the EOS token
if (id == _weights.EndOfSentenceToken) if (id == _weights.Tokens.EOS)
break; break;
// Decode this token into text // Decode this token into text