Revert "Standardizing Image Data implementation"
This reverts commit b2423fe6e9
.
This commit is contained in:
parent
6bd269da60
commit
156d7bb463
|
@ -1,8 +1,7 @@
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
using LLama.Batched;
|
|
||||||
using LLama.Common;
|
using LLama.Common;
|
||||||
using Spectre.Console;
|
using Spectre.Console;
|
||||||
using LLama.Abstractions;
|
using LLama.Native;
|
||||||
|
|
||||||
namespace LLama.Examples.Examples
|
namespace LLama.Examples.Examples
|
||||||
{
|
{
|
||||||
|
@ -19,12 +18,8 @@ namespace LLama.Examples.Examples
|
||||||
|
|
||||||
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
|
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
|
||||||
|
|
||||||
var parameters = new ModelParams(modelPath)
|
var parameters = new ModelParams(modelPath);
|
||||||
{
|
|
||||||
ContextSize = 4096,
|
|
||||||
Seed = 1337,
|
|
||||||
GpuLayerCount = 10
|
|
||||||
};
|
|
||||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||||
using var context = model.CreateContext(parameters);
|
using var context = model.CreateContext(parameters);
|
||||||
|
|
||||||
|
@ -47,16 +42,16 @@ namespace LLama.Examples.Examples
|
||||||
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
|
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
|
||||||
var imageCount = imageMatches.Count();
|
var imageCount = imageMatches.Count();
|
||||||
var hasImages = imageCount > 0;
|
var hasImages = imageCount > 0;
|
||||||
byte[][] imageBytes = null;
|
|
||||||
|
|
||||||
if (hasImages)
|
if (hasImages)
|
||||||
{
|
{
|
||||||
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
|
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
|
||||||
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
|
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
|
||||||
|
|
||||||
|
List<byte[]> imageBytes;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
|
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
|
||||||
}
|
}
|
||||||
catch (IOException exception)
|
catch (IOException exception)
|
||||||
{
|
{
|
||||||
|
@ -69,15 +64,17 @@ namespace LLama.Examples.Examples
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Each prompt with images we clear cache
|
||||||
|
// When the prompt contains images we clear KV_CACHE to restart conversation
|
||||||
|
// See:
|
||||||
|
// https://github.com/ggerganov/llama.cpp/discussions/3620
|
||||||
|
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
foreach (var path in imagePathsWithCurlyBraces)
|
foreach (var path in imagePathsWithCurlyBraces)
|
||||||
{
|
{
|
||||||
// First image replace to tag <image, the rest of the images delete the tag
|
// First image replace to tag <image, the rest of the images delete the tag
|
||||||
if (index++ == 0)
|
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
|
||||||
prompt = prompt.Replace(path, "<image>");
|
|
||||||
else
|
|
||||||
prompt = prompt.Replace(path, "");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +99,7 @@ namespace LLama.Examples.Examples
|
||||||
//
|
//
|
||||||
foreach (var image in imagePaths)
|
foreach (var image in imagePaths)
|
||||||
{
|
{
|
||||||
ex.Images.Add(new ImageData(ImageData.DataType.ImagePath, image));
|
ex.Images.Add(await File.ReadAllBytesAsync(image));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +115,7 @@ namespace LLama.Examples.Examples
|
||||||
|
|
||||||
// let the user finish with exit
|
// let the user finish with exit
|
||||||
//
|
//
|
||||||
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
|
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
|
||||||
break;
|
break;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,9 @@ namespace LLama.Abstractions
|
||||||
public LLavaWeights? ClipModel { get; }
|
public LLavaWeights? ClipModel { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// List of images: Image filen path, uri or image byte array. See ImageData.
|
/// List of images: List of images in byte array format.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public List<ImageData> Images { get; }
|
public List<byte[]> Images { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Asynchronously infers a response from the model.
|
/// Asynchronously infers a response from the model.
|
||||||
|
@ -38,46 +38,4 @@ namespace LLama.Abstractions
|
||||||
/// <returns></returns>
|
/// <returns></returns>
|
||||||
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
|
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Holds image data
|
|
||||||
/// </summary>
|
|
||||||
public class ImageData
|
|
||||||
{
|
|
||||||
/// <summary>
|
|
||||||
/// constructor
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="type"></param>
|
|
||||||
/// <param name="data"></param>
|
|
||||||
public ImageData(DataType type, object data) { Type = type; Data = data; }
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// the possible types of image data
|
|
||||||
/// </summary>
|
|
||||||
public enum DataType
|
|
||||||
{
|
|
||||||
/// <summary>
|
|
||||||
/// file path
|
|
||||||
/// </summary>
|
|
||||||
ImagePath,
|
|
||||||
/// <summary>
|
|
||||||
/// byte array
|
|
||||||
/// </summary>
|
|
||||||
ImageBytes,
|
|
||||||
/// <summary>
|
|
||||||
/// uri
|
|
||||||
/// </summary>
|
|
||||||
ImageURL
|
|
||||||
}
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// the type of this image data
|
|
||||||
/// </summary>
|
|
||||||
public DataType Type { get; set; }
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// the image data (string, byte array or uri)
|
|
||||||
/// </summary>
|
|
||||||
public object? Data { get; set; }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace LLama
|
||||||
public LLavaWeights? ClipModel { get; }
|
public LLavaWeights? ClipModel { get; }
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public List<ImageData> Images { get; set; }
|
public List<byte[]> Images { get; set; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The context used by the executor when running the inference.
|
/// The context used by the executor when running the inference.
|
||||||
|
@ -49,7 +49,7 @@ namespace LLama
|
||||||
/// <param name="logger"></param>
|
/// <param name="logger"></param>
|
||||||
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
|
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
|
||||||
{
|
{
|
||||||
Images = new List<ImageData>();
|
Images = new List<byte[]>();
|
||||||
_weights = weights;
|
_weights = weights;
|
||||||
_params = @params;
|
_params = @params;
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
|
@ -90,7 +90,7 @@ namespace LLama
|
||||||
lastTokens.Add(0);
|
lastTokens.Add(0);
|
||||||
|
|
||||||
// Tokenize the prompt
|
// Tokenize the prompt
|
||||||
var tokens = Context.Tokenize(prompt).ToList();
|
var tokens = Context.Tokenize(prompt, special: true).ToList();
|
||||||
lastTokens.AddRange(tokens);
|
lastTokens.AddRange(tokens);
|
||||||
|
|
||||||
// Evaluate the prompt, in chunks smaller than the max batch size
|
// Evaluate the prompt, in chunks smaller than the max batch size
|
||||||
|
@ -124,7 +124,7 @@ namespace LLama
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is the EOS token
|
// Check if this is the EOS token
|
||||||
if (id == _weights.EndOfSentenceToken)
|
if (id == _weights.Tokens.EOS)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
// Decode this token into text
|
// Decode this token into text
|
||||||
|
|
Loading…
Reference in New Issue