Refactor conflicting object name SessionOptions
This commit is contained in:
parent
44f1b91c29
commit
e2a17d6b6f
|
@ -0,0 +1,13 @@
|
|||
namespace LLama.Web.Common
|
||||
{
|
||||
public interface ISessionConfig
|
||||
{
|
||||
string AntiPrompt { get; set; }
|
||||
List<string> AntiPrompts { get; set; }
|
||||
LLamaExecutorType ExecutorType { get; set; }
|
||||
string Model { get; set; }
|
||||
string OutputFilter { get; set; }
|
||||
List<string> OutputFilters { get; set; }
|
||||
string Prompt { get; set; }
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
namespace LLama.Web.Common
|
||||
{
|
||||
public class SessionOptions
|
||||
public class SessionConfig : ISessionConfig
|
||||
{
|
||||
public string Model { get; set; }
|
||||
public string Prompt { get; set; }
|
|
@ -2,14 +2,14 @@
|
|||
|
||||
namespace LLama.Web
|
||||
{
|
||||
public static class Extensioms
|
||||
public static class Extensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Combines the AntiPrompts list and AntiPrompt csv
|
||||
/// </summary>
|
||||
/// <param name="sessionConfig">The session configuration.</param>
|
||||
/// <returns>Combined AntiPrompts with duplicates removed</returns>
|
||||
public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig)
|
||||
public static List<string> GetAntiPrompts(this ISessionConfig sessionConfig)
|
||||
{
|
||||
return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt);
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ namespace LLama.Web
|
|||
/// </summary>
|
||||
/// <param name="sessionConfig">The session configuration.</param>
|
||||
/// <returns>Combined OutputFilters with duplicates removed</returns>
|
||||
public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig)
|
||||
public static List<string> GetOutputFilters(this ISessionConfig sessionConfig)
|
||||
{
|
||||
return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter);
|
||||
}
|
|
@ -37,7 +37,7 @@ namespace LLama.Web.Hubs
|
|||
|
||||
|
||||
[HubMethodName("LoadModel")]
|
||||
public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig)
|
||||
public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig)
|
||||
{
|
||||
_logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
|
||||
await _modelSessionService.CloseAsync(Context.ConnectionId);
|
||||
|
|
|
@ -9,21 +9,21 @@ namespace LLama.Web.Models
|
|||
private readonly LLamaModel _model;
|
||||
private readonly LLamaContext _context;
|
||||
private readonly ILLamaExecutor _executor;
|
||||
private readonly Common.SessionOptions _sessionParams;
|
||||
private readonly ISessionConfig _sessionConfig;
|
||||
private readonly ITextStreamTransform _outputTransform;
|
||||
private readonly InferenceOptions _defaultInferenceConfig;
|
||||
|
||||
private CancellationTokenSource _cancellationTokenSource;
|
||||
|
||||
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
|
||||
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
|
||||
{
|
||||
_model = model;
|
||||
_context = context;
|
||||
_sessionId = sessionId;
|
||||
_sessionParams = sessionOptions;
|
||||
_sessionConfig = sessionConfig;
|
||||
_defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
|
||||
_outputTransform = CreateOutputFilter(_sessionParams);
|
||||
_executor = CreateExecutor(_model, _context, _sessionParams);
|
||||
_outputTransform = CreateOutputFilter();
|
||||
_executor = CreateExecutor();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -34,7 +34,7 @@ namespace LLama.Web.Models
|
|||
/// <summary>
|
||||
/// Gets the name of the model.
|
||||
/// </summary>
|
||||
public string ModelName => _sessionParams.Model;
|
||||
public string ModelName => _sessionConfig.Model;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the context.
|
||||
|
@ -44,7 +44,7 @@ namespace LLama.Web.Models
|
|||
/// <summary>
|
||||
/// Gets the session configuration.
|
||||
/// </summary>
|
||||
public Common.SessionOptions SessionConfig => _sessionParams;
|
||||
public ISessionConfig SessionConfig => _sessionConfig;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the inference parameters.
|
||||
|
@ -60,16 +60,16 @@ namespace LLama.Web.Models
|
|||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
|
||||
if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
|
||||
return;
|
||||
|
||||
if (string.IsNullOrEmpty(_sessionParams.Prompt))
|
||||
if (string.IsNullOrEmpty(_sessionConfig.Prompt))
|
||||
return;
|
||||
|
||||
// Run Initial prompt
|
||||
var inferenceParams = ConfigureInferenceParams(inferenceConfig);
|
||||
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
|
||||
await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
|
||||
{
|
||||
// We dont really need the response of the initial prompt, so exit on first token
|
||||
break;
|
||||
|
@ -114,13 +114,13 @@ namespace LLama.Web.Models
|
|||
private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
|
||||
{
|
||||
var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
|
||||
inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
|
||||
inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
|
||||
return inferenceParams;
|
||||
}
|
||||
|
||||
private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
|
||||
private ITextStreamTransform CreateOutputFilter()
|
||||
{
|
||||
var outputFilters = sessionConfig.GetOutputFilters();
|
||||
var outputFilters = _sessionConfig.GetOutputFilters();
|
||||
if (outputFilters.Count > 0)
|
||||
return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);
|
||||
|
||||
|
@ -128,9 +128,9 @@ namespace LLama.Web.Models
|
|||
}
|
||||
|
||||
|
||||
private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
|
||||
private ILLamaExecutor CreateExecutor()
|
||||
{
|
||||
return sessionConfig.ExecutorType switch
|
||||
return _sessionConfig.ExecutorType switch
|
||||
{
|
||||
LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
|
||||
LLamaExecutorType.Instruct => new InstructExecutor(_context),
|
||||
|
|
|
@ -24,11 +24,11 @@
|
|||
<div class="d-flex flex-column m-1">
|
||||
<div class="d-flex flex-column mb-2">
|
||||
<small>Model</small>
|
||||
@Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
|
||||
@Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
|
||||
</div>
|
||||
<div class="d-flex flex-column mb-2">
|
||||
<small>Inference Type</small>
|
||||
@Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
|
||||
@Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
|
||||
</div>
|
||||
<nav>
|
||||
<div class="nav nav-tabs" id="nav-tab" role="tablist">
|
||||
|
@ -40,17 +40,17 @@
|
|||
<div class="tab-pane fade show active" id="nav-prompt" role="tabpanel" aria-labelledby="nav-prompt-tab">
|
||||
<div class="d-flex flex-column mb-2">
|
||||
<small>Prompt</small>
|
||||
@Html.TextAreaFor(m => Model.SessionOptions.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
|
||||
@Html.TextAreaFor(m => Model.SessionConfig.Prompt, new { @type="text", @class = "form-control prompt-control", rows=8})
|
||||
</div>
|
||||
|
||||
<div class="d-flex flex-column mb-2">
|
||||
<small>AntiPrompts</small>
|
||||
@Html.TextBoxFor(m => Model.SessionOptions.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
|
||||
@Html.TextBoxFor(m => Model.SessionConfig.AntiPrompt, new { @type="text", @class = "form-control prompt-control"})
|
||||
</div>
|
||||
|
||||
<div class="d-flex flex-column mb-2">
|
||||
<small>OutputFilter</small>
|
||||
@Html.TextBoxFor(m => Model.SessionOptions.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
|
||||
@Html.TextBoxFor(m => Model.SessionConfig.OutputFilter, new { @type="text", @class = "form-control prompt-control"})
|
||||
</div>
|
||||
</div>
|
||||
<div class="tab-pane fade" id="nav-params" role="tabpanel" aria-labelledby="nav-params-tab">
|
||||
|
|
|
@ -18,14 +18,14 @@ namespace LLama.Web.Pages
|
|||
public LLamaOptions Options { get; set; }
|
||||
|
||||
[BindProperty]
|
||||
public Common.SessionOptions SessionOptions { get; set; }
|
||||
public ISessionConfig SessionConfig { get; set; }
|
||||
|
||||
[BindProperty]
|
||||
public InferenceOptions InferenceOptions { get; set; }
|
||||
|
||||
public void OnGet()
|
||||
{
|
||||
SessionOptions = new Common.SessionOptions
|
||||
SessionConfig = new SessionConfig
|
||||
{
|
||||
Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
|
||||
AntiPrompt = "User:",
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace LLama.Web.Services
|
|||
/// Creates a new ModelSession
|
||||
/// </summary>
|
||||
/// <param name="sessionId">The session identifier.</param>
|
||||
/// <param name="sessionOptions">The session configuration.</param>
|
||||
/// <param name="sessionConfig">The session configuration.</param>
|
||||
/// <param name="inferenceOptions">The default inference configuration, will be used for all inference where no infer configuration is supplied.</param>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns></returns>
|
||||
|
@ -33,7 +33,7 @@ namespace LLama.Web.Services
|
|||
/// or
|
||||
/// Failed to create model session
|
||||
/// </exception>
|
||||
Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
|
||||
Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default);
|
||||
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace LLama.Web.Services
|
|||
/// or
|
||||
/// Failed to create model session
|
||||
/// </exception>
|
||||
public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
|
||||
public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_modelSessions.TryGetValue(sessionId, out _))
|
||||
throw new Exception($"Session with id {sessionId} already exists");
|
||||
|
|
Loading…
Reference in New Issue