Correctly passing through mu value to mirostate instead of resetting it every time.

This commit is contained in:
Martin Evans 2023-07-30 00:15:52 +01:00
parent 0e438e6303
commit c64507cb41
5 changed files with 45 additions and 19 deletions

View File

@ -66,6 +66,12 @@ namespace LLama
/// The mode used by the executor.
/// </summary>
public LLamaModel Model => _model;
/// <summary>
/// Current "mu" value for mirostate sampling
/// </summary>
protected float MirostateMu { get; set; } = float.NaN;
/// <summary>
///
/// </summary>
@ -78,8 +84,6 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_embeds = new();
_embed_inps = new();
_last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
}
@ -359,24 +363,36 @@ namespace LLama
{
[JsonPropertyName("n_past")]
public int PastTokensCount { get; set; }
[JsonPropertyName("n_consumed")]
public int ConsumedTokensCount { get; set; }
[JsonPropertyName("n_session_consumed")]
public int ConsumedSessionCount { get; set; }
[JsonPropertyName("n_matching_session_tokens")]
public int MatchingSessionTokensCount { get; set; }
[JsonPropertyName("path_session")]
public string SessionFilePath { get; set; }
[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
[JsonPropertyName("mirostate_mu")]
public float MirostateMu { get; set; }
}
}
}

View File

@ -4,7 +4,6 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@ -20,6 +19,7 @@ namespace LLama
string _instructionPrefix;
llama_token[] _inp_pfx;
llama_token[] _inp_sfx;
/// <summary>
///
/// </summary>
@ -51,7 +51,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
};
return state;
}
@ -214,8 +215,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
var mu = MirostateMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
_last_n_tokens.Enqueue(id);

View File

@ -4,12 +4,8 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
@ -21,6 +17,7 @@ namespace LLama
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;
/// <summary>
///
/// </summary>
@ -46,7 +43,8 @@ namespace LLama
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
};
return state;
}
@ -204,8 +202,12 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
var mu = MirostateMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
_last_n_tokens.Enqueue(id);

View File

@ -220,6 +220,7 @@ namespace LLama
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="candidates"></param>
/// <param name="mirostat_mu"></param>
/// <param name="temperature"></param>
/// <param name="mirostat"></param>
/// <param name="mirostatTau"></param>
@ -229,10 +230,10 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id = 0;
llama_token id;
if (temperature <= 0)
{
// Greedy sampling
@ -240,16 +241,17 @@ namespace LLama
}
else
{
if (float.IsNaN(mirostat_mu))
mirostat_mu = 2 * mirostatTau;
if (mirostat == MiroStateType.MiroState)
{
float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
{
float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}

View File

@ -57,6 +57,7 @@ namespace LLama
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
var mu = float.NaN;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
@ -70,7 +71,7 @@ namespace LLama
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
lastTokens.Add(id);