Rename conflicting SessionOptions class to SessionConfig and introduce ISessionConfig

sa_ddam213 2023-10-04 13:35:18 +13:00
parent 44f1b91c29
commit e2a17d6b6f
9 changed files with 43 additions and 30 deletions

View File

@@ -0,0 +1,13 @@
namespace LLama.Web.Common
{
public interface ISessionConfig
{
string AntiPrompt { get; set; }
List<string> AntiPrompts { get; set; }
LLamaExecutorType ExecutorType { get; set; }
string Model { get; set; }
string OutputFilter { get; set; }
List<string> OutputFilters { get; set; }
string Prompt { get; set; }
}
}
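
The extracted ISessionConfig interface lets downstream code depend on the abstraction rather than the concrete SessionConfig class. A minimal sketch of a consumer, assuming only the members declared above (the Describe helper is hypothetical and not part of this commit):

```csharp
using LLama.Web.Common;

public static class SessionConfigExample
{
    // Hypothetical helper: accepts any ISessionConfig implementation,
    // such as the SessionConfig class renamed later in this commit.
    public static string Describe(ISessionConfig config)
    {
        return $"{config.Model} ({config.ExecutorType}): prompt \"{config.Prompt}\", anti-prompt \"{config.AntiPrompt}\"";
    }
}
```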

View File

@@ -1,6 +1,6 @@
namespace LLama.Web.Common
{
public class SessionOptions
public class SessionConfig : ISessionConfig
{
public string Model { get; set; }
public string Prompt { get; set; }

View File

@@ -2,14 +2,14 @@
namespace LLama.Web
{
public static class Extensioms
public static class Extensions
{
/// <summary>
/// Combines the AntiPrompts list and AntiPrompt csv
/// </summary>
/// <param name="sessionConfig">The session configuration.</param>
/// <returns>Combined AntiPrompts with duplicates removed</returns>
public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig)
public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig)
{
return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt);
}
@@ -19,7 +19,7 @@ namespace LLama.Web
/// </summary>
/// <param name="sessionConfig">The session configuration.</param>
/// <returns>Combined OutputFilters with duplicates removed</returns>
public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig)
public static List<string> GetOutputFilters(this ISessionConfig sessionConfig)
{
return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter);
}
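
Both helpers delegate to a CombineCSV method that is not shown in this hunk. A hedged sketch of what it plausibly does (merge a list with a comma-separated string and drop duplicates); the actual implementation in the repository may differ:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of the CombineCSV helper referenced above.
internal static class CsvCombineSketch
{
    public static List<string> CombineCSV(List<string> list, string csv)
    {
        // Split the CSV string, trim each entry, and merge it with the
        // existing list, removing duplicates.
        var fromCsv = (csv ?? string.Empty)
            .Split(',', StringSplitOptions.RemoveEmptyEntries)
            .Select(x => x.Trim())
            .Where(x => x.Length > 0);

        return (list ?? new List<string>())
            .Concat(fromCsv)
            .Distinct()
            .ToList();
    }
}
```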

View File

@@ -37,7 +37,7 @@ namespace LLama.Web.Hubs
[HubMethodName("LoadModel")]
public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig)
public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig)
{
_logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
await _modelSessionService.CloseAsync(Context.ConnectionId);

View File

@@ -9,21 +9,21 @@ namespace LLama.Web.Models
private readonly LLamaModel _model;
private readonly LLamaContext _context;
private readonly ILLamaExecutor _executor;
private readonly Common.SessionOptions _sessionParams;
private readonly ISessionConfig _sessionConfig;
private readonly ITextStreamTransform _outputTransform;
private readonly InferenceOptions _defaultInferenceConfig;
private CancellationTokenSource _cancellationTokenSource;
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
{
_model = model;
_context = context;
_sessionId = sessionId;
_sessionParams = sessionOptions;
_sessionConfig = sessionConfig;
_defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
_outputTransform = CreateOutputFilter(_sessionParams);
_executor = CreateExecutor(_model, _context, _sessionParams);
_outputTransform = CreateOutputFilter();
_executor = CreateExecutor();
}
/// <summary>
@@ -34,7 +34,7 @@ namespace LLama.Web.Models
/// <summary>
/// Gets the name of the model.
/// </summary>
public string ModelName => _sessionParams.Model;
public string ModelName => _sessionConfig.Model;
/// <summary>
/// Gets the context.
@@ -44,7 +44,7 @@ namespace LLama.Web.Models
/// <summary>
/// Gets the session configuration.
/// </summary>
public Common.SessionOptions SessionConfig => _sessionParams;
public ISessionConfig SessionConfig => _sessionConfig;
/// <summary>
/// Gets the inference parameters.
@@ -60,16 +60,16 @@ namespace LLama.Web.Models
/// <param name="cancellationToken">The cancellation token.</param>
internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
return;
if (string.IsNullOrEmpty(_sessionParams.Prompt))
if (string.IsNullOrEmpty(_sessionConfig.Prompt))
return;
// Run Initial prompt
var inferenceParams = ConfigureInferenceParams(inferenceConfig);
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
{
// We dont really need the response of the initial prompt, so exit on first token
break;
@@ -114,13 +114,13 @@ namespace LLama.Web.Models
private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
{
var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
return inferenceParams;
}
private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
private ITextStreamTransform CreateOutputFilter()
{
var outputFilters = sessionConfig.GetOutputFilters();
var outputFilters = _sessionConfig.GetOutputFilters();
if (outputFilters.Count > 0)
return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);
@@ -128,9 +128,9 @@ namespace LLama.Web.Models
}
private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
private ILLamaExecutor CreateExecutor()
{
return sessionConfig.ExecutorType switch
return _sessionConfig.ExecutorType switch
{
LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
LLamaExecutorType.Instruct => new InstructExecutor(_context),

View File

@@ -24,11 +24,11 @@
<div class="d-flex flex-column m-1">
<div class="d-flex flex-column mb-2">
<small>Model</small>
@Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
@Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
</div>
<div class="d-flex flex-column mb-2">
<small>Inference Type</small>
@Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
@Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
</div>
<nav>
<div class="nav nav-tabs" id="nav-tab" role="tablist">
@@ -40,17 +40,17 @@
<div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab">
<div class="d-flex flex-column mb-2">
<small>Prompt</small>
@Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
@Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
</div>
<div class="d-flex flex-column mb-2">
<small>AntiPrompts</small>
@Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
@Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
</div>
<div class="d-flex flex-column mb-2">
<small>OutputFilter</small>
@Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
@Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
</div>
</div>
<div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab">

View File

@@ -18,14 +18,14 @@ namespace LLama.Web.Pages
public LLamaOptions Options { get; set; }
[BindProperty]
public Common.SessionOptions SessionOptions { get; set; }
public ISessionConfig SessionConfig { get; set; }
[BindProperty]
public InferenceOptions InferenceOptions { get; set; }
public void OnGet()
{
SessionOptions = new Common.SessionOptions
SessionConfig = new SessionConfig
{
Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
AntiPrompt = "User:",

View File

@@ -24,7 +24,7 @@ namespace LLama.Web.Services
/// Creates a new ModelSession
/// </summary>
/// <param name="sessionId">The session identifier.</param>
/// <param name="sessionOptions">The session configuration.</param>
/// <param name="sessionConfig">The session configuration.</param>
/// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
@@ -33,7 +33,7 @@ namespace LLama.Web.Services
/// or
/// Failed to create model session
/// </exception>
Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
/// <summary>

View File

@@ -65,7 +65,7 @@ namespace LLama.Web.Services
/// or
/// Failed to create model session
/// </exception>
public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
if (_modelSessions.TryGetValue(sessionId, out _))
throw new Exception($"Session with id {sessionId} already exists");
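
As a rough usage sketch, this is how a caller such as the SessionConnectionHub shown earlier might invoke the new ISessionConfig-based CreateAsync overload; the configuration values are illustrative assumptions, not taken from this commit:

```csharp
// Illustrative only: assumes an injected IModelSessionService (_modelSessionService)
// and a SignalR connection id, as in SessionConnectionHub.OnLoadModel above.
ISessionConfig sessionConfig = new SessionConfig
{
    Model = "example-model",                      // hypothetical model name
    ExecutorType = LLamaExecutorType.Interactive,
    Prompt = "Below is an instruction that describes a task.",
    AntiPrompt = "User:"
};

ModelSession modelSession = await _modelSessionService.CreateAsync(
    Context.ConnectionId,
    sessionConfig,
    new InferenceOptions(),
    CancellationToken.None);
```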