Added a converter similar to the Open AI one

This commit is contained in:
Ian Foutz 2023-11-18 21:42:34 -06:00
parent 8540c8d220
commit 060d7c273d
4 changed files with 168 additions and 4 deletions

View File

@ -1,4 +1,6 @@
using Microsoft.SemanticKernel.AI;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@ -8,12 +10,14 @@ public class ChatRequestSettings : AIRequestSettings
/// Temperature controls the randomness of the completion.
/// The higher the temperature, the more random the completion.
/// </summary>
[JsonPropertyName("temperature")]
public double Temperature { get; set; } = 0;
/// <summary>
/// TopP controls the diversity of the completion.
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
public double TopP { get; set; } = 0;
/// <summary>
@ -21,6 +25,7 @@ public class ChatRequestSettings : AIRequestSettings
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
[JsonPropertyName("presence_penalty")]
public double PresencePenalty { get; set; } = 0;
/// <summary>
@ -28,11 +33,13 @@ public class ChatRequestSettings : AIRequestSettings
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
[JsonPropertyName("frequency_penalty")]
public double FrequencyPenalty { get; set; } = 0;
/// <summary>
/// Sequences where the completion will stop generating further tokens.
/// </summary>
[JsonPropertyName("stop_sequences")]
public IList<string> StopSequences { get; set; } = Array.Empty<string>();
/// <summary>
@ -40,15 +47,67 @@ public class ChatRequestSettings : AIRequestSettings
/// Note: Because this parameter generates many completions, it can quickly consume your token quota.
/// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
/// </summary>
[JsonPropertyName("results_per_prompt")]
public int ResultsPerPrompt { get; set; } = 1;
/// <summary>
/// The maximum number of tokens to generate in the completion.
/// </summary>
[JsonPropertyName("max_tokens")]
public int? MaxTokens { get; set; }
/// <summary>
/// Modify the likelihood of specified tokens appearing in the completion.
/// </summary>
[JsonPropertyName("token_selection_biases")]
public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();
/// <summary>
/// Create a new settings object with the values from another settings object.
/// </summary>
/// <param name="requestSettings">Template configuration</param>
/// <param name="defaultMaxTokens">Default max tokens</param>
/// <returns>An instance of OpenAIRequestSettings</returns>
public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
{
if (requestSettings is null)
{
return new ChatRequestSettings()
{
MaxTokens = defaultMaxTokens
};
}
if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
{
return requestSettingsChatRequestSettings;
}
var json = JsonSerializer.Serialize(requestSettings);
var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
if (chatRequestSettings is not null)
{
return chatRequestSettings;
}
throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
}
private static readonly JsonSerializerOptions s_options = CreateOptions();
private static JsonSerializerOptions CreateOptions()
{
JsonSerializerOptions options = new()
{
WriteIndented = true,
MaxDepth = 20,
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
Converters = { new ChatRequestSettingsConverter() }
};
return options;
}
}

View File

@ -0,0 +1,105 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// <summary>
/// JSON converter for <see cref="OpenAIRequestSettings"/>
/// </summary>
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
{
/// <inheritdoc/>
public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var requestSettings = new ChatRequestSettings();
while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
string? propertyName = reader.GetString();
if (propertyName is not null)
{
// normalise property name to uppercase
propertyName = propertyName.ToUpperInvariant();
}
reader.Read();
switch (propertyName)
{
case "TEMPERATURE":
requestSettings.Temperature = reader.GetDouble();
break;
case "TOPP":
case "TOP_P":
requestSettings.TopP = reader.GetDouble();
break;
case "FREQUENCYPENALTY":
case "FREQUENCY_PENALTY":
requestSettings.FrequencyPenalty = reader.GetDouble();
break;
case "PRESENCEPENALTY":
case "PRESENCE_PENALTY":
requestSettings.PresencePenalty = reader.GetDouble();
break;
case "MAXTOKENS":
case "MAX_TOKENS":
requestSettings.MaxTokens = reader.GetInt32();
break;
case "STOPSEQUENCES":
case "STOP_SEQUENCES":
requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
break;
case "RESULTSPERPROMPT":
case "RESULTS_PER_PROMPT":
requestSettings.ResultsPerPrompt = reader.GetInt32();
break;
case "TOKENSELECTIONBIASES":
case "TOKEN_SELECTION_BIASES":
requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
break;
case "SERVICEID":
case "SERVICE_ID":
requestSettings.ServiceId = reader.GetString();
break;
default:
reader.Skip();
break;
}
}
}
return requestSettings;
}
/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
{
writer.WriteStartObject();
writer.WriteNumber("temperature", value.Temperature);
writer.WriteNumber("top_p", value.TopP);
writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
writer.WriteNumber("presence_penalty", value.PresencePenalty);
if (value.MaxTokens is null)
{
writer.WriteNull("max_tokens");
}
else
{
writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
}
writer.WritePropertyName("stop_sequences");
JsonSerializer.Serialize(writer, value.StopSequences, options);
writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
writer.WritePropertyName("token_selection_biases");
JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
writer.WriteString("service_id", value.ServiceId);
writer.WriteEndObject();
}
}

View File

@ -61,7 +61,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var settings = requestSettings != null
? (ChatRequestSettings)requestSettings
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
@ -76,7 +76,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
#pragma warning restore CS1998
{
var settings = requestSettings != null
? (ChatRequestSettings)requestSettings
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.

View File

@ -21,7 +21,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion
public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
{
var settings = (ChatRequestSettings?)requestSettings;
var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
}
@ -30,7 +30,7 @@ public sealed class LLamaSharpTextCompletion : ITextCompletion
public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings,[EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
var settings = (ChatRequestSettings?)requestSettings;
var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
yield return new LLamaTextResult(result);
}