Fixed spelling of "mirostat" instead of "mirostate"

This commit is contained in:
Martin Evans 2023-07-27 23:11:25 +01:00
parent 4d7d4f2bfe
commit 36735f7908
2 changed files with 22 additions and 8 deletions

View File

@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Common
{
@ -83,7 +82,7 @@ namespace LLama.Common
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>
public MiroStateType Mirostat { get; set; } = MiroStateType.Disable;
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>
@ -98,10 +97,25 @@ namespace LLama.Common
public bool PenalizeNL { get; set; } = true;
}
public enum MiroStateType
/// <summary>
/// Type of "mirostat" sampling to use.
/// https://github.com/basusourya/mirostat
/// </summary>
public enum MirostatType
{
/// <summary>
/// Disable Mirostat sampling
/// </summary>
Disable = 0,
MiroState = 1,
MiroState2 = 2
/// <summary>
/// Original mirostat algorithm
/// </summary>
Mirostat = 1,
/// <summary>
/// Mirostat 2.0 algorithm
/// </summary>
Mirostat2 = 2
}
}

View File

@ -229,7 +229,7 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id = 0;
@ -240,14 +240,14 @@ namespace LLama
}
else
{
if (mirostat == MiroStateType.MiroState)
if (mirostat == MirostatType.Mirostat)
{
float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
else if (mirostat == MirostatType.Mirostat2)
{
float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);