Removed unnecessary parameters from some low-level sampler methods

This commit is contained in:
Martin Evans 2023-08-26 21:38:24 +01:00
parent f70525fec2
commit cf4754db44
4 changed files with 82 additions and 4 deletions

View File

@ -0,0 +1,14 @@
using System.Collections.Generic;
namespace LLama.Extensions
{
internal static class DictionaryExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill for <c>IReadOnlyDictionary.GetValueOrDefault</c>, which is not available on netstandard2.0.
    /// </summary>
    /// <param name="dictionary">Dictionary to look up in.</param>
    /// <param name="key">Key to look up.</param>
    /// <param name="defaultValue">Value returned when <paramref name="key"/> is not present.</param>
    /// <returns>The value stored under <paramref name="key"/>, or <paramref name="defaultValue"/> if the key is absent.</returns>
    public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
    {
        if (dictionary.TryGetValue(key, out var found))
            return found;

        return defaultValue;
    }
#endif
}
}

View File

@ -0,0 +1,21 @@
using System.Collections.Generic;
using System.Linq;
namespace LLama.Extensions
{
internal static class IEnumerableExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill for <c>Enumerable.TakeLast</c> (not available on netstandard2.0): returns the last
    /// <paramref name="count"/> elements of <paramref name="source"/>.
    /// </summary>
    /// <param name="source">Sequence to take elements from. Unlike the BCL version this is enumerated eagerly.</param>
    /// <param name="count">Number of trailing elements to keep. Non-positive values yield an empty sequence, matching the BCL contract.</param>
    /// <returns>The final <paramref name="count"/> elements, or the whole sequence if it is shorter than <paramref name="count"/>.</returns>
    public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count)
    {
        // Match Enumerable.TakeLast: count <= 0 produces an empty sequence.
        // (Previously a negative count made RemoveRange throw ArgumentException.)
        if (count <= 0)
            return Enumerable.Empty<T>();

        var list = source.ToList();
        if (count >= list.Count)
            return list;

        list.RemoveRange(0, list.Count - count);
        return list;
    }
#endif
}
}

View File

@ -2,6 +2,8 @@
using System.Buffers;
using System.Runtime.InteropServices;
using llama_token = System.Int32;
namespace LLama.Native
{
/// <summary>
@ -15,9 +17,9 @@ namespace LLama.Native
public readonly Memory<LLamaTokenData> data;
/// <summary>
/// Indicates if `data` is sorted
/// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
/// </summary>
public readonly bool sorted;
public bool sorted;
/// <summary>
/// Create a new LLamaTokenDataArray
@ -29,6 +31,20 @@ namespace LLama.Native
data = tokens;
sorted = isSorted;
}
/// <summary>
/// Create a new LLamaTokenDataArray, copying the data from the given logits
/// </summary>
/// <param name="logits">Source logits; element <c>i</c> becomes the logit for token id <c>i</c>.</param>
/// <returns>A token data array built from the copied logits (each entry created with probability 0).</returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
{
    var candidates = new LLamaTokenData[logits.Length];

    // The token id is simply the index into the logits span.
    var index = 0;
    foreach (var logit in logits)
    {
        candidates[index] = new LLamaTokenData(index, logit, 0.0f);
        index++;
    }

    return new LLamaTokenDataArray(candidates);
}
}
/// <summary>

View File

@ -25,12 +25,25 @@ namespace LLama.Native
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
// Backwards-compatibility shim: last_tokens_size is ignored, the overload
// forwarded to derives the size from last_tokens itself.
llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
}
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Recently sampled tokens; the full length of this memory is passed to the native call.</param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // Pin the managed buffer so its address is stable while native code reads it.
    using var last_tokens_handle = last_tokens.Pin();

    // The token count is derived from last_tokens directly; the stale call passing a
    // separate last_tokens_size (no longer a parameter of this overload) is removed.
    NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}
/// <summary>
@ -42,12 +55,26 @@ namespace LLama.Native
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
// Backwards-compatibility shim: last_tokens_size is ignored, the overload
// forwarded to derives the size from last_tokens itself.
llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
}
/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Recently sampled tokens; the full length of this memory is passed to the native call.</param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // Pin the managed buffer so its address is stable while native code reads it.
    using var last_tokens_handle = last_tokens.Pin();

    // The token count is derived from last_tokens directly; the stale call passing a
    // separate last_tokens_size (no longer a parameter of this overload) is removed.
    NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
}
/// <summary>