2023-06-16 04:06:47 +08:00
using LLama.Abstractions ;
using LLama.Common ;
2023-06-12 18:07:41 +08:00
using System.Collections.Generic ;
using System.Linq ;
using System.Text ;
2024-03-17 22:34:36 +08:00
using System.Text.Json.Serialization ;
2023-06-12 18:07:41 +08:00
namespace LLama
{
2023-06-20 02:38:57 +08:00
/// <summary>
/// A class that contains all the transforms provided internally by LLama.
/// </summary>
2023-06-12 18:07:41 +08:00
public class LLamaTransforms
{
2023-06-16 04:23:58 +08:00
/// <summary>
/// The default history transform.
/// Uses plain text with the following format:
/// [Author]: [Message]
/// </summary>
public class DefaultHistoryTransform : IHistoryTransform
{
    private const string defaultUserName = "User";
    private const string defaultAssistantName = "Assistant";
    private const string defaultSystemName = "System";
    private const string defaultUnknownName = "??";

    private readonly string _userName;
    private readonly string _assistantName;
    private readonly string _systemName;
    private readonly string _unknownName;
    private readonly bool _isInstructMode;

    /// <summary>
    /// Name written in front of messages whose role is <c>AuthorRole.User</c>.
    /// </summary>
    public string UserName => _userName;

    /// <summary>
    /// Name written in front of messages whose role is <c>AuthorRole.Assistant</c>.
    /// </summary>
    public string AssistantName => _assistantName;

    /// <summary>
    /// Name written in front of messages whose role is <c>AuthorRole.System</c>.
    /// </summary>
    public string SystemName => _systemName;

    /// <summary>
    /// Name written in front of messages whose role is <c>AuthorRole.Unknown</c>.
    /// </summary>
    public string UnknownName => _unknownName;

    /// <summary>
    /// When true, a trailing "\n> " is also stripped from assistant text
    /// by <see cref="TrimNamesFromText"/>.
    /// </summary>
    public bool IsInstructMode => _isInstructMode;

    /// <summary>
    /// Creates a history transform. Every name that is null falls back to its default
    /// ("User", "Assistant", "System" and "??" respectively).
    /// </summary>
    /// <param name="userName">Name used for user messages.</param>
    /// <param name="assistantName">Name used for assistant messages.</param>
    /// <param name="systemName">Name used for system messages.</param>
    /// <param name="unknownName">Name used for messages with an unknown role.</param>
    /// <param name="isInstructMode">Whether the trailing "\n> " marker is removed from assistant text.</param>
    public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
        string? systemName = null, string? unknownName = null, bool isInstructMode = false)
    {
        _userName = userName ?? defaultUserName;
        _assistantName = assistantName ?? defaultAssistantName;
        _systemName = systemName ?? defaultSystemName;
        _unknownName = unknownName ?? defaultUnknownName;
        _isInstructMode = isInstructMode;
    }

    /// <inheritdoc />
    public IHistoryTransform Clone()
    {
        return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode);
    }

    /// <inheritdoc />
    public virtual string HistoryToText(ChatHistory history)
    {
        StringBuilder sb = new();
        foreach (var message in history.Messages)
        {
            // Messages whose role is none of the four handled below are skipped.
            if (message.AuthorRole == AuthorRole.User)
            {
                sb.AppendLine($"{_userName}: {message.Content}");
            }
            else if (message.AuthorRole == AuthorRole.System)
            {
                sb.AppendLine($"{_systemName}: {message.Content}");
            }
            else if (message.AuthorRole == AuthorRole.Unknown)
            {
                sb.AppendLine($"{_unknownName}: {message.Content}");
            }
            else if (message.AuthorRole == AuthorRole.Assistant)
            {
                sb.AppendLine($"{_assistantName}: {message.Content}");
            }
        }
        return sb.ToString();
    }

    /// <inheritdoc />
    public virtual ChatHistory TextToHistory(AuthorRole role, string text)
    {
        ChatHistory history = new ChatHistory();
        history.AddMessage(role, TrimNamesFromText(text, role));
        return history;
    }

    /// <summary>
    /// Removes the author marker from the text: a leading "[userName]:" for the user role,
    /// or a trailing "[assistantName]:" for the assistant role. In instruct mode a trailing
    /// "\n> " is additionally removed from assistant text.
    /// </summary>
    /// <param name="text">The text to clean up.</param>
    /// <param name="role">The role the text belongs to.</param>
    /// <returns>The text without the matched markers and the whitespace next to them.</returns>
    public virtual string TrimNamesFromText(string text, AuthorRole role)
    {
        var userPrefix = $"{_userName}:";
        var assistantSuffix = $"{_assistantName}:";
        const string instructSuffix = "\n> ";

        // Ordinal comparison is required here: the matched part is cut off using the marker's
        // Length, which is only correct when the match is an exact char-by-char match.
        // (System is fully qualified because the file has no `using System;` directive.)
        if (role == AuthorRole.User && text.StartsWith(userPrefix, System.StringComparison.Ordinal))
        {
            text = text.Substring(userPrefix.Length).TrimStart();
        }
        else if (role == AuthorRole.Assistant && text.EndsWith(assistantSuffix, System.StringComparison.Ordinal))
        {
            text = text.Substring(0, text.Length - assistantSuffix.Length).TrimEnd();
        }

        if (_isInstructMode && role == AuthorRole.Assistant && text.EndsWith(instructSuffix, System.StringComparison.Ordinal))
        {
            text = text.Substring(0, text.Length - instructSuffix.Length).TrimEnd();
        }

        return text;
    }
}
/// <summary>
/// A text input transform whose only effect is trimming leading and trailing
/// whitespace from the text.
/// </summary>
public class NaiveTextInputTransform
    : ITextTransform
{
    /// <inheritdoc />
    public string Transform(string text) => text.Trim();

    /// <inheritdoc />
    public ITextTransform Clone() => new NaiveTextInputTransform();
}
2023-10-14 06:54:01 +08:00
2023-06-16 04:23:58 +08:00
/// <summary>
/// A no-op text output stream transform: the token stream is handed back unchanged.
/// </summary>
public class EmptyTextOutputStreamTransform
    : ITextStreamTransform
{
    /// <inheritdoc />
    public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) => tokens;

    /// <inheritdoc />
    public ITextStreamTransform Clone() => new EmptyTextOutputStreamTransform();
}
2023-10-14 06:54:01 +08:00
2023-06-16 04:23:58 +08:00
/// <summary>
/// A text output transform that removes the keywords from the response.
/// </summary>
public class KeywordTextOutputStreamTransform : ITextStreamTransform
{
    private readonly HashSet<string> _keywords;
    private readonly int _maxKeywordLength;
    private readonly bool _removeAllMatchedTokens;

    /// <summary>
    /// Keywords that you want to remove from the response.
    /// This property is used for JSON serialization.
    /// </summary>
    [JsonPropertyName("keywords")]
    public HashSet<string> Keywords => _keywords;

    /// <summary>
    /// Length of the buffered text at which the buffered tokens are released.
    /// When built from keywords this is the longest keyword length plus the redundancy length.
    /// This property is used for JSON serialization.
    /// </summary>
    [JsonPropertyName("maxKeywordLength")]
    public int MaxKeywordLength => _maxKeywordLength;

    /// <summary>
    /// If set to true, when getting a matched keyword, all the related tokens will be removed.
    /// Otherwise only the part of keyword will be removed.
    /// This property is used for JSON serialization.
    /// </summary>
    [JsonPropertyName("removeAllMatchedTokens")]
    public bool RemoveAllMatchedTokens => _removeAllMatchedTokens;

    /// <summary>
    /// JSON constructor. Takes the already computed <paramref name="maxKeywordLength"/>
    /// as-is (no redundancy is added) and copies the keyword set.
    /// </summary>
    /// <param name="keywords">Keywords that you want to remove from the response.</param>
    /// <param name="maxKeywordLength">Buffered text length at which buffered tokens are released.</param>
    /// <param name="removeAllMatchedTokens">If set to true, all tokens involved in a match are dropped.</param>
    [JsonConstructor]
    public KeywordTextOutputStreamTransform(
        HashSet<string> keywords,
        int maxKeywordLength,
        bool removeAllMatchedTokens)
    {
        _keywords = new(keywords);
        _maxKeywordLength = maxKeywordLength;
        _removeAllMatchedTokens = removeAllMatchedTokens;
    }

    /// <summary>
    /// Creates a transform that removes the given keywords from a token stream.
    /// </summary>
    /// <param name="keywords">Keywords that you want to remove from the response.</param>
    /// <param name="redundancyLength">The extra length added to the longest keyword when deciding how much text
    /// to buffer while searching for a keyword. Buffered tokens are released as soon as their combined length reaches
    /// that limit, so a keyword that arrives split over several tokens and is preceded by other characters may be
    /// released before it is complete. For example, if your only keyword is "highlight" (9 characters), the text
    /// "\r\nhighlight" is 11 characters long: with redundancyLength = 0 the buffer can be released before the keyword
    /// is complete, whereas redundancyLength of 2 or more leaves room for the match.
    /// The larger the redundancyLength is, the lower the processing speed; according to the original author,
    /// values up to 5 do not introduce much performance impact.</param>
    /// <param name="removeAllMatchedTokens">If set to true, when getting a matched keyword, all the related tokens will be removed. Otherwise only the part of keyword will be removed.</param>
    public KeywordTextOutputStreamTransform(IEnumerable<string> keywords, int redundancyLength = 3, bool removeAllMatchedTokens = false)
    {
        _keywords = new(keywords);
        // DefaultIfEmpty keeps an empty keyword collection usable (Max() alone would throw),
        // matching the JSON constructor which also accepts an empty set.
        _maxKeywordLength = _keywords.Select(x => x.Length).DefaultIfEmpty(0).Max() + redundancyLength;
        _removeAllMatchedTokens = removeAllMatchedTokens;
    }

    /// <inheritdoc />
    public ITextStreamTransform Clone()
    {
        // Binds to the JSON constructor (HashSet overload), so the stored length is reused unchanged.
        return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens);
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
    {
        // Tokens that have been received but not yet released to the caller.
        var window = new Queue<string>();

        await foreach (var s in tokens)
        {
            window.Enqueue(s);
            var current = string.Join("", window);

            if (_keywords.Any(x => current.Contains(x)))
            {
                // A keyword was found in the buffered text: the buffered tokens are consumed.
                window.Clear();

                if (!_removeAllMatchedTokens)
                {
                    // Replace is a no-op for keywords that do not occur, so no pre-filtering is needed.
                    foreach (var keyword in _keywords)
                    {
                        current = current.Replace(keyword, "");
                    }
                    yield return current;
                }
            }

            // The buffer is long enough that it is released as it is.
            // (After a match the window is empty, so nothing is released here.)
            if (current.Length >= _maxKeywordLength)
            {
                while (window.Count > 0)
                {
                    yield return window.Dequeue();
                }
            }
        }

        // End of stream: release whatever is still buffered.
        while (window.Count > 0)
        {
            yield return window.Dequeue();
        }
    }
}
}
}