82 lines
2.7 KiB
C#
82 lines
2.7 KiB
C#
using System.Diagnostics;
|
|
using LLama.Common;
|
|
using LLama.Sampling;
|
|
using Xunit.Abstractions;
|
|
|
|
namespace LLama.Unittest
|
|
{
|
|
public class StatelessExecutorTest
|
|
: IDisposable
|
|
{
|
|
private readonly ITestOutputHelper _testOutputHelper;
|
|
private readonly LLamaWeights _weights;
|
|
private readonly ModelParams _params;
|
|
|
|
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
|
|
{
|
|
_testOutputHelper = testOutputHelper;
|
|
_params = new ModelParams(Constants.ModelPath)
|
|
{
|
|
ContextSize = 60,
|
|
Seed = 1754,
|
|
};
|
|
_weights = LLamaWeights.LoadFromFile(_params);
|
|
}
|
|
|
|
public void Dispose()
|
|
{
|
|
_weights.Dispose();
|
|
}
|
|
|
|
[Fact]
|
|
public async Task Stateless()
|
|
{
|
|
// Create a custom pipeline that mimics the default pipeline
|
|
var pipeline = new DefaultSamplingPipeline();
|
|
|
|
var executor = new StatelessExecutor(_weights, _params);
|
|
|
|
const string question = "Question. what is a cat?\nAnswer: ";
|
|
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
|
|
|
|
var timer = new Stopwatch();
|
|
timer.Start();
|
|
|
|
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
|
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
|
|
|
timer.Stop();
|
|
_testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");
|
|
|
|
_testOutputHelper.WriteLine(result1);
|
|
_testOutputHelper.WriteLine(result2);
|
|
|
|
// Check that it produced the exact same result both times
|
|
Assert.Equal(result1, result2);
|
|
}
|
|
|
|
[Fact(Skip = "Very very slow in CI")]
|
|
public async Task OutOfContext()
|
|
{
|
|
var executor = new StatelessExecutor(_weights, _params);
|
|
|
|
const string question = " Question. cats or dogs?\nAnswer: ";
|
|
|
|
// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
|
|
// with a modified context
|
|
var @params = new InferenceParams()
|
|
{
|
|
MaxTokens = 65,
|
|
TokensKeep = question.Length,
|
|
};
|
|
|
|
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
|
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
|
|
|
_testOutputHelper.WriteLine(result1);
|
|
|
|
// Check that it produced the exact same result both times
|
|
Assert.Equal(result1, result2);
|
|
}
|
|
}
|
|
} |