// LLamaSharp/LLama.Web/Models/ModelSession.cs

using LLama.Abstractions;
using LLama.Web.Common;

namespace LLama.Web.Models
{
    public class ModelSession
    {
        private readonly string _sessionId;
        private readonly LLamaModel _model;
        private readonly LLamaContext _context;
        private readonly ILLamaExecutor _executor;
        private readonly ISessionConfig _sessionConfig;
        private readonly ITextStreamTransform _outputTransform;
        private readonly InferenceOptions _defaultInferenceConfig;
        private CancellationTokenSource _cancellationTokenSource;

        public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
        {
            _model = model;
            _context = context;
            _sessionId = sessionId;
            _sessionConfig = sessionConfig;
            _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
            _outputTransform = CreateOutputFilter();
            _executor = CreateExecutor();
        }
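
        // Usage sketch (not part of the original file): how a session might be wired up.
        // SessionConfig and the model/context acquisition are assumptions about the
        // surrounding LLama.Web code, not confirmed API.
        //
        //     var config = new SessionConfig
        //     {
        //         Model = "llama-2-7b",                          // must match a configured model
        //         ExecutorType = LLamaExecutorType.Interactive,
        //         Prompt = "You are a helpful assistant."
        //     };
        //     var session = new ModelSession(model, context, sessionId, config);
        //     await session.InitializePrompt();                  // primes the context with the prompt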

        /// <summary>
        /// Gets the session identifier.
        /// </summary>
        public string SessionId => _sessionId;

        /// <summary>
        /// Gets the name of the model.
        /// </summary>
        public string ModelName => _sessionConfig.Model;

        /// <summary>
        /// Gets the context.
        /// </summary>
        public LLamaContext Context => _context;

        /// <summary>
        /// Gets the session configuration.
        /// </summary>
        public ISessionConfig SessionConfig => _sessionConfig;

        /// <summary>
        /// Gets the default inference parameters.
        /// </summary>
        public InferenceOptions InferenceParams => _defaultInferenceConfig;

        /// <summary>
        /// Runs the session's configured prompt through the executor to prime the context.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            // Stateless executors keep no state between calls, so there is nothing to prime
            if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
                return;

            if (string.IsNullOrEmpty(_sessionConfig.Prompt))
                return;

            // Run the initial prompt
            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
            {
                // We don't need the response to the initial prompt, so exit on the first token
                break;
            }
        }

        /// <summary>
        /// Runs inference on the model context.
        /// </summary>
        /// <param name="message">The message.</param>
        /// <param name="inferenceConfig">The inference configuration.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>An async stream of generated text, filtered through the output transform when one is configured.</returns>
        internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
            if (_outputTransform is not null)
                return _outputTransform.TransformAsync(inferenceStream);

            return inferenceStream;
        }
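
        // Usage sketch (an assumption, not from the original file): the stream is meant
        // to be consumed with await foreach; the sink shown here is hypothetical.
        //
        //     await foreach (var token in session.InferAsync(message, cancellationToken: requestToken))
        //     {
        //         await responseStream.WriteAsync(token);        // hypothetical output sink
        //     }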

        public void CancelInfer()
        {
            _cancellationTokenSource?.Cancel();
        }

        public bool IsInferCanceled()
        {
            // Null-safe: no inference may have been started yet
            return _cancellationTokenSource?.IsCancellationRequested ?? false;
        }
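
        // Usage sketch (hedged): cancellation typically arrives from a different request
        // than the one consuming the stream.
        //
        //     session.CancelInfer();                             // e.g. from a "cancel" endpoint
        //     // the consumer's await foreach then ends once the linked token is observed
        //     var wasCanceled = session.IsInferCanceled();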

        /// <summary>
        /// Configures the inference parameters, applying the session's anti-prompts.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration, or null to use the session default.</param>
        private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
        {
            var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
            inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
            return inferenceParams;
        }

        private ITextStreamTransform CreateOutputFilter()
        {
            // Only create a transform when the session defines output filter keywords
            var outputFilters = _sessionConfig.GetOutputFilters();
            if (outputFilters.Count > 0)
                return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

            return null;
        }

        private ILLamaExecutor CreateExecutor()
        {
            return _sessionConfig.ExecutorType switch
            {
                LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
                LLamaExecutorType.Instruct => new InstructExecutor(_context),
                LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _context.Params),
                _ => default
            };
        }
    }
}