Removed unnecessary parameters from some low level sampler methods
This commit is contained in:
parent
f70525fec2
commit
cf4754db44
|
@ -0,0 +1,14 @@
|
|||
using System.Collections.Generic;

namespace LLama.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="IReadOnlyDictionary{TKey,TValue}"/>.
    /// </summary>
    internal static class DictionaryExtensions
    {
#if NETSTANDARD2_0
        /// <summary>
        /// Get the value stored under <paramref name="key"/>, or <paramref name="defaultValue"/>
        /// when the key is not present. Polyfill for the BCL method that is
        /// unavailable on netstandard2.0.
        /// </summary>
        public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
        {
            if (dictionary.TryGetValue(key, out var value))
                return value;

            return defaultValue;
        }
#endif
    }
}
|
|
@ -0,0 +1,21 @@
|
|||
using System.Collections.Generic;
using System.Linq;

namespace LLama.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="IEnumerable{T}"/>.
    /// </summary>
    internal static class IEnumerableExtensions
    {
#if NETSTANDARD2_0
        /// <summary>
        /// Return the last <paramref name="count"/> elements of <paramref name="source"/>.
        /// Polyfill for <c>Enumerable.TakeLast</c>, which is unavailable on netstandard2.0.
        /// </summary>
        /// <param name="source">Sequence to take elements from (fully materialized by this call).</param>
        /// <param name="count">Number of trailing elements to keep; non-positive yields an empty sequence.</param>
        public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count)
        {
            // Match the BCL TakeLast contract: a non-positive count produces an
            // empty sequence. (Previously a negative count made RemoveRange throw
            // ArgumentException instead.)
            if (count <= 0)
                return Enumerable.Empty<T>();

            var list = source.ToList();

            if (count >= list.Count)
                return list;

            list.RemoveRange(0, list.Count - count);
            return list;
        }
#endif
    }
}
|
|
@ -2,6 +2,8 @@
|
|||
using System.Buffers;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
using llama_token = System.Int32;
|
||||
|
||||
namespace LLama.Native
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -15,9 +17,9 @@ namespace LLama.Native
|
|||
public readonly Memory<LLamaTokenData> data;
|
||||
|
||||
/// <summary>
|
||||
/// Indicates if `data` is sorted
|
||||
/// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
|
||||
/// </summary>
|
||||
public readonly bool sorted;
|
||||
public bool sorted;
|
||||
|
||||
/// <summary>
|
||||
/// Create a new LLamaTokenDataArray
|
||||
|
@ -29,6 +31,20 @@ namespace LLama.Native
|
|||
data = tokens;
|
||||
sorted = isSorted;
|
||||
}
|
||||
|
||||
/// <summary>
/// Create a new LLamaTokenDataArray, copying the data from the given logits
/// </summary>
/// <param name="logits">Raw logits, one per token; the array index is used as the token id.</param>
/// <returns>A token data array built from the logits (third field of each entry initialised to 0).</returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
{
    var candidates = new LLamaTokenData[logits.Length];

    var index = 0;
    while (index < logits.Length)
    {
        candidates[index] = new LLamaTokenData(index, logits[index], 0.0f);
        index++;
    }

    return new LLamaTokenDataArray(candidates);
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -25,12 +25,25 @@ namespace LLama.Native
|
|||
/// <param name="last_tokens"></param>
|
||||
/// <param name="last_tokens_size"></param>
|
||||
/// <param name="penalty"></param>
|
||||
// Backwards-compatible forwarder: the explicit size argument is ignored because
// the replacement overload derives the length from last_tokens itself.
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
    => llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
|
||||
|
||||
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Recently sampled tokens to penalise; the entire span is used.</param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    using var last_tokens_handle = last_tokens.Pin();

    // The native API still requires an explicit element count; compute it from
    // the memory region rather than taking a separate parameter. (The rendered
    // diff contained a stale duplicate call referencing the removed
    // last_tokens_size parameter — only this call is valid.)
    NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -42,12 +55,26 @@ namespace LLama.Native
|
|||
/// <param name="last_tokens_size"></param>
|
||||
/// <param name="alpha_frequency"></param>
|
||||
/// <param name="alpha_presence"></param>
|
||||
// Backwards-compatible forwarder: the explicit size argument is ignored because
// the replacement overload derives the length from last_tokens itself.
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
    => llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
|
||||
|
||||
/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Recently sampled tokens to penalise; the entire span is used.</param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    using var last_tokens_handle = last_tokens.Pin();

    // The native API still requires an explicit element count; compute it from
    // the memory region rather than taking a separate parameter. (The rendered
    // diff contained a stale duplicate call referencing the removed
    // last_tokens_size parameter — only this call is valid.)
    NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
}
|
||||
|
||||
/// <summary>
|
||||
|
|
Loading…
Reference in New Issue