Change interface to support multiple images and add the capability to render the images in the console
commit 43677c511c
parent 2d9a114f66
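In rough outline, the new usage looks like the sketch below. This is a minimal, hedged sketch assembled from the example changed in this diff; the model path, mmproj path and image paths are placeholders, and LLavaWeights.LoadFromFile is assumed from the existing LLamaSharp API rather than shown in this commit.

    // Sketch only - paths and settings are illustrative, not part of this commit.
    using System.Linq;
    using LLama;
    using LLama.Common;
    using Spectre.Console;

    var parameters = new ModelParams("<modelPath>") { ContextSize = 4096, GpuLayerCount = 5 };
    using var model = LLamaWeights.LoadFromFile(parameters);
    using var context = model.CreateContext(parameters);
    using var clipModel = LLavaWeights.LoadFromFile("<multiModalProjPath>"); // assumed API, not part of this diff

    var ex = new InteractiveExecutor(context, clipModel);

    // New: a list of image paths instead of a single ImagePath.
    ex.ImagePaths = new List<string> { "<image1.jpg>", "<image2.jpg>" };

    // New: render scaled-down previews of the images in the console via Spectre.Console.
    foreach (var bytes in ex.ImagePaths.Select(File.ReadAllBytes))
    {
        var consoleImage = new CanvasImage(bytes) { MaxWidth = 50 };
        AnsiConsole.Write(consoleImage);
    }

    // The <image> tag marks where the image embeddings are injected into the prompt.
    var inferenceParams = new InferenceParams { Temperature = 0.1f, AntiPrompts = new List<string> { "\nUSER:" }, MaxTokens = 1024 };
    await foreach (var text in ex.InferAsync("<image>\nUSER:\nProvide a full description of the image.\nASSISTANT:\n", inferenceParams))
        Console.Write(text);

All images in ImagePaths are embedded at the position of the <image> tag; the console previews are scaled down, while the model receives the full-size images.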
@@ -1,4 +1,7 @@
using LLama.Common;
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;

namespace LLama.Examples.Examples
{

@@ -8,15 +11,15 @@ namespace LLama.Examples.Examples
        {
            string multiModalProj = UserSettings.GetMMProjPath();
            string modelPath = UserSettings.GetModelPath();
            string imagePath = UserSettings.GetImagePath();
            string modelImage = UserSettings.GetImagePath();
            const int maxTokens = 1024;

            var prompt = (await File.ReadAllTextAsync("Assets/vicuna-llava-v16.txt")).Trim();
            var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

            var parameters = new ModelParams(modelPath)
            {
                ContextSize = 4096,
                Seed = 1337,
                GpuLayerCount = 5
            };
            using var model = LLamaWeights.LoadFromFile(parameters);
            using var context = model.CreateContext(parameters);

@@ -26,26 +29,93 @@ namespace LLama.Examples.Examples

            var ex = new InteractiveExecutor(context, clipModel);

            ex.ImagePath = imagePath;
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize);
            Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}.");

            var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "\nUSER:" }, MaxTokens = maxTokens };

            do
            {

                // Evaluate if we have images
                //
                var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
                var imageCount = imageMatches.Count();
                var hasImages = imageCount > 0;
                byte[][] imageBytes = null;

                if (hasImages)
                {
                    var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
                    var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);

                    try
                    {
                        imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
                    }
                    catch (IOException exception)
                    {
                        Console.ForegroundColor = ConsoleColor.Red;
                        Console.Write(
                            $"Could not load your {(imageCount == 1 ? "image" : "images")}:");
                        Console.Write($"{exception.Message}");
                        Console.ForegroundColor = ConsoleColor.Yellow;
                        Console.WriteLine("Please try again.");
                        break;
                    }

                    int index = 0;
                    foreach (var path in imagePathsWithCurlyBraces)
                    {
                        // Replace the first image placeholder with the <image> tag; remove the remaining placeholders
                        if (index++ == 0)
                            prompt = prompt.Replace(path, "<image>");
                        else
                            prompt = prompt.Replace(path, "");
                    }

                    Console.ForegroundColor = ConsoleColor.Yellow;
                    Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 1024 and the context size is 4096. ");
                    Console.ForegroundColor = ConsoleColor.White;
                    Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message.");
                    Console.WriteLine();

                    Console.Write(prompt);

                    var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "USER:" }, MaxTokens = 1024 };

                    while (true)
                    foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)))
                    {
                        consoleImage.MaxWidth = 50;
                        AnsiConsole.Write(consoleImage);
                    }

                    Console.WriteLine();
                    Console.ForegroundColor = ConsoleColor.Yellow;
                    Console.WriteLine($"The images were scaled down for the console only, the model gets full versions.");
                    Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu.");
                    Console.WriteLine();

                    // Initialize images in the executor
                    //
                    ex.ImagePaths = imagePaths.ToList();
                }

                Console.ForegroundColor = Color.White;
                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                {
                    Console.Write(text);
                }
                Console.Write(" ");
                Console.ForegroundColor = ConsoleColor.Green;
                prompt = Console.ReadLine();
                Console.ForegroundColor = ConsoleColor.White;
                Console.WriteLine();

                // Let the user exit with /exit
                //
                if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
                    break;

            }
            while (true);
        }
    }
}

@@ -19,6 +19,7 @@
    <PackageReference Include="Microsoft.SemanticKernel" Version="1.6.2" />
    <PackageReference Include="Microsoft.SemanticKernel.Plugins.Memory" Version="1.6.2-alpha" />
    <PackageReference Include="Spectre.Console" Version="0.48.0" />
    <PackageReference Include="Spectre.Console.ImageSharp" Version="0.48.0" />
  </ItemGroup>

  <ItemGroup>

@@ -25,9 +25,9 @@ namespace LLama.Abstractions
        public LLavaWeights? ClipModel { get; }

        /// <summary>
        /// Image filename and path (jpeg images).
        /// List of images: Image filename and path (jpeg images).
        /// </summary>
        public string? ImagePath { get; set; }
        public List<string> ImagePaths { get; set; }

        /// <summary>

@@ -71,7 +71,7 @@ namespace LLama
        {
            get
            {
                return ClipModel != null && ImagePath != null;
                return ClipModel != null;
            }
        }

@@ -79,7 +79,7 @@ namespace LLama
        public LLavaWeights? ClipModel { get; }

        /// <inheritdoc />
        public string? ImagePath { get; set; }
        public List<string> ImagePaths { get; set; }

        /// <summary>
        /// Current "mu" value for mirostat sampling

@@ -95,6 +95,7 @@ namespace LLama
        /// <param name="logger"></param>
        protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
        {
            ImagePaths = new List<string>();
            _logger = logger;
            Context = context;
            _pastTokensCount = 0;

@@ -24,7 +24,7 @@ namespace LLama

        // LLava
        private int _EmbedImagePosition = -1;
        private SafeLlavaImageEmbedHandle _imageEmbedHandle = null;
        private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
        private bool _imageInPrompt = false;

        /// <summary>

@@ -125,30 +125,7 @@ namespace LLama
            }
            else
            {
                // If the prompt contains the tag <image> extract this.
                _imageInPrompt = text.Contains("<image>");
                if (_imageInPrompt)
                {
                    if (!string.IsNullOrEmpty(ImagePath))
                    {
                        _imageEmbedHandle = SafeLlavaImageEmbedHandle.CreateFromFileName(ClipModel.NativeHandle, Context, ImagePath);
                    }

                    int imageIndex = text.IndexOf("<image>");
                    // Tokenize segment 1 (before <image> tag)
                    string preImagePrompt = text.Substring(0, imageIndex);
                    var segment1 = Context.Tokenize(preImagePrompt, true);
                    // Remember the position to add the image embeddings
                    _EmbedImagePosition = segment1.Length;
                    string postImagePrompt = text.Substring(imageIndex + 7);
                    var segment2 = Context.Tokenize(postImagePrompt, false);
                    _embed_inps.AddRange(segment1);
                    _embed_inps.AddRange(segment2);
                }
                else
                {
                    _embed_inps = Context.Tokenize(text, true).ToList();
                }
                PreprocessLlava(text, args, true);
            }
        }
        else

@@ -157,6 +134,7 @@ namespace LLama
            {
                text += "\n";
            }

            var line_inp = Context.Tokenize(text, false);
            _embed_inps.AddRange(line_inp);
            args.RemainedTokens -= line_inp.Length;

@@ -165,6 +143,37 @@ namespace LLama
            return Task.CompletedTask;
        }

        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
        {
            int usedTokens = 0;
            // If the prompt contains the tag <image> extract this.
            _imageInPrompt = text.Contains("<image>");
            if (_imageInPrompt)
            {
                foreach (var image in ImagePaths)
                {
                    _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName(ClipModel.NativeHandle, Context, image));
                }

                int imageIndex = text.IndexOf("<image>");
                // Tokenize segment 1 (before <image> tag)
                string preImagePrompt = text.Substring(0, imageIndex);
                var segment1 = Context.Tokenize(preImagePrompt, addBos);
                // Remember the position to add the image embeddings
                _EmbedImagePosition = segment1.Length;
                string postImagePrompt = text.Substring(imageIndex + 7);
                var segment2 = Context.Tokenize(postImagePrompt, false);
                _embed_inps.AddRange(segment1);
                _embed_inps.AddRange(segment2);
                usedTokens += (segment1.Length + segment2.Length);
            }
            else
            {
                _embed_inps = Context.Tokenize(text, true).ToList();
            }
            return Task.CompletedTask;
        }

        /// <summary>
        /// Return whether to break the generation.
        /// </summary>

@@ -216,18 +225,19 @@ namespace LLama
            (DecodeResult, int) header, end, result;
            if (IsMultiModal && _EmbedImagePosition > 0)
            {
                // Previous to Image
                // Tokens previous to the images
                header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
                if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

                // Image
                ClipModel.EvalImageEmbed(Context, _imageEmbedHandle, ref _pastTokensCount);
                // Images
                foreach (var image in _imageEmbedHandles)
                    ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

                // Post-image
                // Post-image Tokens
                end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);

                _EmbedImagePosition = -1;

                _imageEmbedHandles.Clear();
            }
            else
            {

@@ -27,8 +27,8 @@ namespace LLama
        // LLava Section
        public bool IsMultiModal => false;
        public bool MultiModalProject { get; }
        public LLavaWeights ClipModel { get; }
        public string ImagePath { get; set; }
        public LLavaWeights? ClipModel { get; }
        public List<string> ImagePaths { get; set; }

        /// <summary>
        /// The context used by the executor when running the inference.

@@ -43,6 +43,7 @@ namespace LLama
        /// <param name="logger"></param>
        public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            ImagePaths = new List<string>();
            _weights = weights;
            _params = @params;
            _logger = logger;