InferenceParams abstractions
This commit is contained in:
parent
2a04e31b7d
commit
bac9cba01a
|
@ -1,9 +1,99 @@
|
|||
using LLama.Common;
|
||||
using LLama.Abstractions;
|
||||
|
||||
namespace LLama.Web.Common
|
||||
{
|
||||
public class ParameterOptions : InferenceParams
|
||||
{
|
||||
public class ParameterOptions : IInferenceParams
|
||||
{
|
||||
public string Name { get; set; }
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// number of tokens to keep from initial prompt
|
||||
/// </summary>
|
||||
public int TokensKeep { get; set; } = 0;
|
||||
/// <summary>
|
||||
/// how many new tokens to predict (n_predict), set to -1 to infinitely generate a response
|
||||
/// until it completes.
|
||||
/// </summary>
|
||||
public int MaxTokens { get; set; } = -1;
|
||||
/// <summary>
|
||||
/// logit bias for specific tokens
|
||||
/// </summary>
|
||||
public Dictionary<int, float>? LogitBias { get; set; } = null;
|
||||
|
||||
/// <summary>
|
||||
/// Sequences where the model will stop generating further tokens.
|
||||
/// </summary>
|
||||
public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
|
||||
/// <summary>
|
||||
/// path to file for saving/loading model eval state
|
||||
/// </summary>
|
||||
public string PathSession { get; set; } = string.Empty;
|
||||
/// <summary>
|
||||
/// string to suffix user inputs with
|
||||
/// </summary>
|
||||
public string InputSuffix { get; set; } = string.Empty;
|
||||
/// <summary>
|
||||
/// string to prefix user inputs with
|
||||
/// </summary>
|
||||
public string InputPrefix { get; set; } = string.Empty;
|
||||
/// <summary>
|
||||
/// 0 or lower to use vocab size
|
||||
/// </summary>
|
||||
public int TopK { get; set; } = 40;
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TopP { get; set; } = 0.95f;
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TfsZ { get; set; } = 1.0f;
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TypicalP { get; set; } = 1.0f;
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float Temperature { get; set; } = 0.8f;
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float RepeatPenalty { get; set; } = 1.1f;
|
||||
/// <summary>
|
||||
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
|
||||
/// </summary>
|
||||
public int RepeatLastTokensCount { get; set; } = 64;
|
||||
/// <summary>
|
||||
/// frequency penalty coefficient
|
||||
/// 0.0 = disabled
|
||||
/// </summary>
|
||||
public float FrequencyPenalty { get; set; } = .0f;
|
||||
/// <summary>
|
||||
/// presence penalty coefficient
|
||||
/// 0.0 = disabled
|
||||
/// </summary>
|
||||
public float PresencePenalty { get; set; } = .0f;
|
||||
/// <summary>
|
||||
/// Mirostat uses tokens instead of words.
|
||||
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
|
||||
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
/// </summary>
|
||||
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
|
||||
/// <summary>
|
||||
/// target entropy
|
||||
/// </summary>
|
||||
public float MirostatTau { get; set; } = 5.0f;
|
||||
/// <summary>
|
||||
/// learning rate
|
||||
/// </summary>
|
||||
public float MirostatEta { get; set; } = 0.1f;
|
||||
/// <summary>
|
||||
/// consider newlines as a repeatable token (penalize_nl)
|
||||
/// </summary>
|
||||
public bool PenalizeNL { get; set; } = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
using System.Collections.Generic;
|
||||
using LLama.Common;
|
||||
|
||||
namespace LLama.Abstractions
|
||||
{
|
||||
/// <summary>
|
||||
/// The parameters used for inference.
|
||||
/// </summary>
|
||||
public interface IInferenceParams
|
||||
{
|
||||
/// <summary>
|
||||
/// number of tokens to keep from initial prompt
|
||||
/// </summary>
|
||||
public int TokensKeep { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// how many new tokens to predict (n_predict), set to -1 to infinitely generate a response
|
||||
/// until it completes.
|
||||
/// </summary>
|
||||
public int MaxTokens { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// logit bias for specific tokens
|
||||
/// </summary>
|
||||
public Dictionary<int, float>? LogitBias { get; set; }
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Sequences where the model will stop generating further tokens.
|
||||
/// </summary>
|
||||
public IEnumerable<string> AntiPrompts { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// path to file for saving/loading model eval state
|
||||
/// </summary>
|
||||
public string PathSession { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// string to suffix user inputs with
|
||||
/// </summary>
|
||||
public string InputSuffix { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// string to prefix user inputs with
|
||||
/// </summary>
|
||||
public string InputPrefix { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 0 or lower to use vocab size
|
||||
/// </summary>
|
||||
public int TopK { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TopP { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TfsZ { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float TypicalP { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float Temperature { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 1.0 = disabled
|
||||
/// </summary>
|
||||
public float RepeatPenalty { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
|
||||
/// </summary>
|
||||
public int RepeatLastTokensCount { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// frequency penalty coefficient
|
||||
/// 0.0 = disabled
|
||||
/// </summary>
|
||||
public float FrequencyPenalty { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// presence penalty coefficient
|
||||
/// 0.0 = disabled
|
||||
/// </summary>
|
||||
public float PresencePenalty { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Mirostat uses tokens instead of words.
|
||||
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
|
||||
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
/// </summary>
|
||||
public MirostatType Mirostat { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// target entropy
|
||||
/// </summary>
|
||||
public float MirostatTau { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// learning rate
|
||||
/// </summary>
|
||||
public float MirostatEta { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// consider newlines as a repeatable token (penalize_nl)
|
||||
/// </summary>
|
||||
public bool PenalizeNL { get; set; }
|
||||
}
|
||||
}
|
|
@ -23,7 +23,7 @@ namespace LLama.Abstractions
|
|||
/// <param name="inferenceParams">Any additional parameters</param>
|
||||
/// <param name="token">A cancellation token.</param>
|
||||
/// <returns></returns>
|
||||
IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||
IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||
|
||||
/// <summary>
|
||||
/// Asynchronously infers a response from the model.
|
||||
|
@ -32,6 +32,6 @@ namespace LLama.Abstractions
|
|||
/// <param name="inferenceParams">Any additional parameters</param>
|
||||
/// <param name="token">A cancellation token.</param>
|
||||
/// <returns></returns>
|
||||
IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,7 +138,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var prompt = HistoryTransform.HistoryToText(history);
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
|
||||
|
@ -159,7 +159,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach(var inputTransform in InputTransformPipeline)
|
||||
{
|
||||
|
@ -182,7 +182,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var prompt = HistoryTransform.HistoryToText(history);
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
|
||||
|
@ -202,7 +202,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var inputTransform in InputTransformPipeline)
|
||||
{
|
||||
|
@ -218,13 +218,13 @@ namespace LLama
|
|||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
|
||||
}
|
||||
|
||||
private IEnumerable<string> ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
|
||||
return OutputTransform.Transform(results);
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
|
||||
await foreach (var item in OutputTransform.TransformAsync(results))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using System;
|
||||
using LLama.Abstractions;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace LLama.Common
|
||||
|
@ -7,7 +8,7 @@ namespace LLama.Common
|
|||
/// <summary>
|
||||
/// The parameters used for inference.
|
||||
/// </summary>
|
||||
public class InferenceParams
|
||||
public class InferenceParams : IInferenceParams
|
||||
{
|
||||
/// <summary>
|
||||
/// number of tokens to keep from initial prompt
|
||||
|
|
|
@ -231,13 +231,13 @@ namespace LLama
|
|||
/// <param name="args"></param>
|
||||
/// <param name="extraOutputs"></param>
|
||||
/// <returns></returns>
|
||||
protected abstract bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
|
||||
protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
|
||||
/// <summary>
|
||||
/// The core inference logic.
|
||||
/// </summary>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="args"></param>
|
||||
protected abstract void InferInternal(InferenceParams inferenceParams, InferStateArgs args);
|
||||
protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
|
||||
/// <summary>
|
||||
/// Save the current state to a file.
|
||||
/// </summary>
|
||||
|
@ -267,7 +267,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public virtual IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
if (inferenceParams is null)
|
||||
|
@ -324,7 +324,7 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public virtual async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var result in Infer(text, inferenceParams, cancellationToken))
|
||||
{
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using LLama.Common;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
@ -136,7 +137,7 @@ namespace LLama
|
|||
}
|
||||
}
|
||||
/// <inheritdoc />
|
||||
protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
{
|
||||
extraOutputs = null;
|
||||
if (_embed_inps.Count <= _consumedTokensCount)
|
||||
|
@ -179,7 +180,7 @@ namespace LLama
|
|||
return false;
|
||||
}
|
||||
/// <inheritdoc />
|
||||
protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
|
||||
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
if (_embeds.Count > 0)
|
||||
{
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
using LLama.Abstractions;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
@ -122,7 +123,7 @@ namespace LLama
|
|||
/// </summary>
|
||||
/// <param name="args"></param>
|
||||
/// <returns></returns>
|
||||
protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
{
|
||||
extraOutputs = null;
|
||||
if (_embed_inps.Count <= _consumedTokensCount)
|
||||
|
@ -166,7 +167,7 @@ namespace LLama
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
|
||||
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
if (_embeds.Count > 0)
|
||||
{
|
||||
|
|
|
@ -36,7 +36,7 @@ namespace LLama
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
int n_past = 1;
|
||||
|
@ -123,7 +123,7 @@ namespace LLama
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var result in Infer(text, inferenceParams, cancellationToken))
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue