Change interface to support multiple images and add the capability to render images in the console

SignalRT 2024-03-26 23:13:39 +01:00
parent 2d9a114f66
commit 43677c511c
6 changed files with 132 additions and 49 deletions
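
Taken together, the change is easiest to see from the caller's side. A minimal usage sketch, assuming the LLamaSharp API as of this commit; the file paths, prompt, and parameter values are placeholders rather than part of the diff:

using LLama;
using LLama.Common;

// Placeholder paths and parameters, mirroring the example below.
var parameters = new ModelParams("models/llava-v1.6.gguf") { ContextSize = 4096, GpuLayerCount = 5 };
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
using var clipModel = LLavaWeights.LoadFromFile("models/mmproj.gguf");

var ex = new InteractiveExecutor(context, clipModel);
// The old single ImagePath property is replaced by a list of paths:
ex.ImagePaths = new List<string> { "a.jpg", "b.jpg" };

var inferenceParams = new InferenceParams { Temperature = 0.1f, AntiPrompts = new List<string> { "USER:" }, MaxTokens = 1024 };
// The <image> tag marks where the image embeddings are injected:
await foreach (var token in ex.InferAsync("USER: <image>\nDescribe the images.\nASSISTANT:", inferenceParams))
    Console.Write(token);
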

View File

@@ -1,4 +1,7 @@
using LLama.Common;
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
namespace LLama.Examples.Examples
{
@@ -8,15 +11,15 @@ namespace LLama.Examples.Examples
{
string multiModalProj = UserSettings.GetMMProjPath();
string modelPath = UserSettings.GetModelPath();
string imagePath = UserSettings.GetImagePath();
string modelImage = UserSettings.GetImagePath();
const int maxTokens = 1024;
var prompt = (await File.ReadAllTextAsync("Assets/vicuna-llava-v16.txt")).Trim();
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
@@ -26,26 +29,93 @@ namespace LLama.Examples.Examples
var ex = new InteractiveExecutor(context, clipModel );
ex.ImagePath = imagePath;
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}.");
var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "\nUSER:" }, MaxTokens = maxTokens };
do
{
// Check whether the prompt references any images
//
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;
if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
}
catch (IOException exception)
{
Console.ForegroundColor = ConsoleColor.Red;
Console.Write(
$"Could not load your {(imageCount == 1 ? "image" : "images")}:");
Console.Write($"{exception.Message}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Please try again.");
break;
}
int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// Replace the first image path with the <image> tag; remove the tags for the remaining images
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
}
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 1024 and the context size is 4096. ");
Console.ForegroundColor = ConsoleColor.White;
Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message.");
Console.WriteLine();
Console.Write(prompt);
var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "USER:" }, MaxTokens = 1024 };
while (true)
foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)))
{
consoleImage.MaxWidth = 50;
AnsiConsole.Write(consoleImage);
}
Console.WriteLine();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine($"The images were scaled down for the console only, the model gets full versions.");
Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu.");
Console.WriteLine();
// Initialize images in the executor
//
ex.ImagePaths = imagePaths.ToList();
}
Console.ForegroundColor = Color.White;
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}
Console.Write(" ");
Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.WriteLine();
// Let the user finish with /exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;
}
while(true);
}
}
}

View File

@@ -19,6 +19,7 @@
<PackageReference Include="Microsoft.SemanticKernel" Version="1.6.2" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Memory" Version="1.6.2-alpha" />
<PackageReference Include="Spectre.Console" Version="0.48.0" />
<PackageReference Include="Spectre.Console.ImageSharp" Version="0.48.0" />
</ItemGroup>
<ItemGroup>

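The new Spectre.Console.ImageSharp package supplies the CanvasImage type used in the example above. A minimal sketch of the rendering call, with a placeholder file path:

using Spectre.Console;

// CanvasImage lives in the Spectre.Console.ImageSharp package.
var consoleImage = new CanvasImage(File.ReadAllBytes("image.jpg"));
consoleImage.MaxWidth = 50;   // scaled down for console display only
AnsiConsole.Write(consoleImage);
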
View File

@@ -25,9 +25,9 @@ namespace LLama.Abstractions
public LLavaWeights? ClipModel { get; }
/// <summary>
/// Image filename and path (jpeg images).
/// List of images: Image filename and path (jpeg images).
/// </summary>
public string? ImagePath { get; set; }
public List<string> ImagePaths { get; set; }
/// <summary>

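For implementers of this abstraction the migration is mechanical. A sketch of a hypothetical implementer after the change (remaining interface members elided); initializing the list avoids a null check in the multimodal paths, matching what the executors below do in their constructors:

public class MyExecutor : ILLamaExecutor
{
    public LLavaWeights? ClipModel { get; }

    // Replaces the old `string? ImagePath`.
    public List<string> ImagePaths { get; set; } = new();

    // ... remaining ILLamaExecutor members elided ...
}
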
View File

@@ -71,7 +71,7 @@ namespace LLama
{
get
{
return ClipModel != null && ImagePath != null;
return ClipModel != null;
}
}
@@ -79,7 +79,7 @@
public LLavaWeights? ClipModel { get; }
/// <inheritdoc />
public string? ImagePath { get; set; }
public List<string> ImagePaths { get; set; }
/// <summary>
/// Current "mu" value for mirostat sampling
@@ -95,6 +95,7 @@ namespace LLama
/// <param name="logger"></param>
protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
ImagePaths = new List<string>();
_logger = logger;
Context = context;
_pastTokensCount = 0;

View File

@@ -24,7 +24,7 @@ namespace LLama
// LLava
private int _EmbedImagePosition = -1;
private SafeLlavaImageEmbedHandle _imageEmbedHandle = null;
private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
private bool _imageInPrompt = false;
/// <summary>
@@ -125,30 +125,7 @@ namespace LLama
}
else
{
// If the prompt contains the <image> tag, extract it.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt)
{
if (!string.IsNullOrEmpty(ImagePath))
{
_imageEmbedHandle = SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, ImagePath);
}
int imageIndex = text.IndexOf("<image>");
// Tokenize segment 1 (before <image> tag)
string preImagePrompt = text.Substring(0, imageIndex);
var segment1 = Context.Tokenize(preImagePrompt, true);
// Remember the position to add the image embeddings
_EmbedImagePosition = segment1.Length;
string postImagePrompt = text.Substring(imageIndex + 7);
var segment2 = Context.Tokenize(postImagePrompt, false);
_embed_inps.AddRange(segment1);
_embed_inps.AddRange(segment2);
}
else
{
_embed_inps = Context.Tokenize(text, true).ToList();
}
PreprocessLlava(text, args, true );
}
}
else
@@ -157,6 +134,7 @@ namespace LLama
{
text += "\n";
}
var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
@@ -165,6 +143,37 @@ namespace LLama
return Task.CompletedTask;
}
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
{
int usedTokens = 0;
// If the prompt contains the <image> tag, extract it.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt)
{
foreach (var image in ImagePaths)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, image ) );
}
int imageIndex = text.IndexOf("<image>");
// Tokenize segment 1 (before <image> tag)
string preImagePrompt = text.Substring(0, imageIndex);
var segment1 = Context.Tokenize(preImagePrompt, addBos );
// Remember the position to add the image embeddings
_EmbedImagePosition = segment1.Length;
string postImagePrompt = text.Substring(imageIndex + 7);
var segment2 = Context.Tokenize(postImagePrompt, false);
_embed_inps.AddRange(segment1);
_embed_inps.AddRange(segment2);
usedTokens += (segment1.Length + segment2.Length);
}
else
{
_embed_inps = Context.Tokenize(text, addBos).ToList();
}
return Task.CompletedTask;
}
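
To make the split concrete, here is what PreprocessLlava computes for a prompt like the example's, shown as a standalone sketch over string offsets (the 7 in the code above is the length of the "<image>" tag):

// Illustrative only: the executor tokenizes both halves and injects the
// embeddings of every image in ImagePaths at the boundary.
const string tag = "<image>";                                      // tag.Length == 7
var text = "USER: <image>\nCompare the images.\nASSISTANT:";
int imageIndex = text.IndexOf(tag);                                // 6
string preImagePrompt  = text.Substring(0, imageIndex);            // "USER: "
string postImagePrompt = text.Substring(imageIndex + tag.Length);  // "\nCompare the images.\nASSISTANT:"
// _EmbedImagePosition = Tokenize(preImagePrompt).Length
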
/// <summary>
/// Return whether to break the generation.
/// </summary>
@@ -216,18 +225,19 @@ namespace LLama
(DecodeResult, int) header, end, result;
if (IsMultiModal && _EmbedImagePosition > 0)
{
// Previous to Image
// Tokens previous to the images
header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
// Image
ClipModel.EvalImageEmbed(Context, _imageEmbedHandle, ref _pastTokensCount);
// Images
foreach( var image in _imageEmbedHandles )
ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);
// Post-image
// Post-image Tokens
end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
_EmbedImagePosition = -1;
_imageEmbedHandles.Clear();
}
else
{

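The ordering in the multimodal branch matters because _pastTokensCount tracks the next position in the context's KV cache: the pre-image tokens, each image embedding in prompt order, and the post-image tokens must all advance it consecutively. Condensed from the hunk above as a sketch, with error handling elided:

// Decode text before the images, evaluate every image embedding at
// that position, then decode the remaining text.
Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
foreach (var image in _imageEmbedHandles)
    ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);
Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);

// Reset so the next call takes the text-only path.
_EmbedImagePosition = -1;
_imageEmbedHandles.Clear();
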
View File

@@ -27,8 +27,8 @@ namespace LLama
// LLava Section
public bool IsMultiModal => false;
public bool MultiModalProject { get; }
public LLavaWeights ClipModel { get; }
public string ImagePath { get; set; }
public LLavaWeights? ClipModel { get; }
public List<string> ImagePaths { get; set; }
/// <summary>
/// The context used by the executor when running the inference.
@@ -43,6 +43,7 @@ namespace LLama
/// <param name="logger"></param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
ImagePaths = new List<string>();
_weights = weights;
_params = @params;
_logger = logger;