Correctly pass the mu value through to mirostat sampling instead of resetting it every time.
This commit is contained in:
parent
0e438e6303
commit
c64507cb41
|
@ -66,6 +66,12 @@ namespace LLama
|
|||
/// The mode used by the executor.
|
||||
/// </summary>
|
||||
public LLamaModel Model => _model;
|
||||
|
||||
/// <summary>
|
||||
/// Current "mu" value for mirostate sampling
|
||||
/// </summary>
|
||||
protected float MirostateMu { get; set; } = float.NaN;
|
||||
|
||||
/// <summary>
|
||||
///
|
||||
/// </summary>
|
||||
|
@ -78,8 +84,6 @@ namespace LLama
|
|||
_pastTokensCount = 0;
|
||||
_consumedTokensCount = 0;
|
||||
_n_session_consumed = 0;
|
||||
_embeds = new();
|
||||
_embed_inps = new();
|
||||
_last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
|
||||
}
|
||||
|
||||
|
@ -359,24 +363,36 @@ namespace LLama
|
|||
{
|
||||
[JsonPropertyName("n_past")]
|
||||
public int PastTokensCount { get; set; }
|
||||
|
||||
[JsonPropertyName("n_consumed")]
|
||||
public int ConsumedTokensCount { get; set; }
|
||||
|
||||
[JsonPropertyName("n_session_consumed")]
|
||||
public int ConsumedSessionCount { get; set; }
|
||||
|
||||
[JsonPropertyName("n_matching_session_tokens")]
|
||||
public int MatchingSessionTokensCount { get; set; }
|
||||
|
||||
[JsonPropertyName("path_session")]
|
||||
public string SessionFilePath { get; set; }
|
||||
|
||||
[JsonPropertyName("embd")]
|
||||
public List<llama_token> Embeds { get; set; }
|
||||
|
||||
[JsonPropertyName("embd_inps")]
|
||||
public List<llama_token> EmbedInps { get; set; }
|
||||
|
||||
[JsonPropertyName("session_tokens")]
|
||||
public List<llama_token> SessionTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("last_n_tokens")]
|
||||
public llama_token[] LastTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("last_tokens_maximum_count")]
|
||||
public int LastTokensCapacity { get; set; }
|
||||
|
||||
[JsonPropertyName("mirostate_mu")]
|
||||
public float MirostateMu { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
|
@ -20,6 +19,7 @@ namespace LLama
|
|||
string _instructionPrefix;
|
||||
llama_token[] _inp_pfx;
|
||||
llama_token[] _inp_sfx;
|
||||
|
||||
/// <summary>
|
||||
///
|
||||
/// </summary>
|
||||
|
@ -51,7 +51,8 @@ namespace LLama
|
|||
PastTokensCount = _pastTokensCount,
|
||||
SessionFilePath = _pathSession,
|
||||
SessionTokens = _session_tokens,
|
||||
LastTokensCapacity = _last_n_tokens.Capacity
|
||||
LastTokensCapacity = _last_n_tokens.Capacity,
|
||||
MirostateMu = MirostateMu
|
||||
};
|
||||
return state;
|
||||
}
|
||||
|
@ -214,8 +215,12 @@ namespace LLama
|
|||
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
|
||||
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
|
||||
|
||||
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
|
||||
var mu = MirostateMu;
|
||||
var id = _model.Sample(
|
||||
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
|
||||
);
|
||||
MirostateMu = mu;
|
||||
|
||||
_last_n_tokens.Enqueue(id);
|
||||
|
||||
|
|
|
@ -4,12 +4,8 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace LLama
|
||||
{
|
||||
|
@ -21,6 +17,7 @@ namespace LLama
|
|||
{
|
||||
bool _is_prompt_run = true;
|
||||
llama_token[] _llama_token_newline;
|
||||
|
||||
/// <summary>
|
||||
///
|
||||
/// </summary>
|
||||
|
@ -46,7 +43,8 @@ namespace LLama
|
|||
PastTokensCount = _pastTokensCount,
|
||||
SessionFilePath = _pathSession,
|
||||
SessionTokens = _session_tokens,
|
||||
LastTokensCapacity = _last_n_tokens.Capacity
|
||||
LastTokensCapacity = _last_n_tokens.Capacity,
|
||||
MirostateMu = MirostateMu
|
||||
};
|
||||
return state;
|
||||
}
|
||||
|
@ -204,8 +202,12 @@ namespace LLama
|
|||
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
|
||||
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
|
||||
|
||||
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
|
||||
var mu = MirostateMu;
|
||||
var id = _model.Sample(
|
||||
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
|
||||
);
|
||||
MirostateMu = mu;
|
||||
|
||||
_last_n_tokens.Enqueue(id);
|
||||
|
||||
|
|
|
@ -220,6 +220,7 @@ namespace LLama
|
|||
/// Perform the sampling. Please don't use it unless you fully know what it does.
|
||||
/// </summary>
|
||||
/// <param name="candidates"></param>
|
||||
/// <param name="mirostat_mu"></param>
|
||||
/// <param name="temperature"></param>
|
||||
/// <param name="mirostat"></param>
|
||||
/// <param name="mirostatTau"></param>
|
||||
|
@ -229,10 +230,10 @@ namespace LLama
|
|||
/// <param name="tfsZ"></param>
|
||||
/// <param name="typicalP"></param>
|
||||
/// <returns></returns>
|
||||
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
|
||||
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
|
||||
public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
|
||||
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
|
||||
{
|
||||
llama_token id = 0;
|
||||
llama_token id;
|
||||
if (temperature <= 0)
|
||||
{
|
||||
// Greedy sampling
|
||||
|
@ -240,16 +241,17 @@ namespace LLama
|
|||
}
|
||||
else
|
||||
{
|
||||
if (float.IsNaN(mirostat_mu))
|
||||
mirostat_mu = 2 * mirostatTau;
|
||||
|
||||
if (mirostat == MiroStateType.MiroState)
|
||||
{
|
||||
float mirostat_mu = 2.0f * mirostatTau;
|
||||
const int mirostat_m = 100;
|
||||
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
|
||||
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
|
||||
}
|
||||
else if (mirostat == MiroStateType.MiroState2)
|
||||
{
|
||||
float mirostat_mu = 2.0f * mirostatTau;
|
||||
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
|
||||
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ namespace LLama
|
|||
lastTokens.AddRange(tokens);
|
||||
n_past += n_prompt_tokens;
|
||||
|
||||
var mu = float.NaN;
|
||||
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
|
||||
for(int i = 0; i < max_tokens; i++)
|
||||
{
|
||||
|
@ -70,7 +71,7 @@ namespace LLama
|
|||
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
|
||||
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
|
||||
|
||||
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
|
||||
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
|
||||
|
||||
lastTokens.Add(id);
|
||||
|
|
Loading…
Reference in New Issue