refactor: allow customized logger.

Yaohui Liu 2023-06-12 03:11:44 +08:00
parent 3bf74ec9b9
commit b567399b65
No known key found for this signature in database
GPG Key ID: E86D01E1809BD23E
4 changed files with 123 additions and 49 deletions


@@ -1,15 +1,35 @@
using System;
using System.Diagnostics;
using System.IO;
using static LLama.Common.ILLamaLogger;
namespace LLama.Types;
namespace LLama.Common;
public interface ILLamaLogger
{
public enum LogLevel
{
Info,
Debug,
Warning,
Error
}
/// <summary>
/// Write the log in a customized way.
/// </summary>
/// <param name="source">The source of the log. It may be a method name or class name.</param>
/// <param name="message">The message.</param>
/// <param name="level">The log level.</param>
void Log(string source, string message, LogLevel level);
}
/// <summary>
/// The logger of LLamaSharp. By default it writes to the console. Use methods of `LLamaLogger.Default` to change the behavior.
/// The default logger of LLamaSharp. By default it writes to the console. Use methods of `LLamaDefaultLogger.Default` to change the behavior.
/// It is recommended to implement `ILLamaLogger` to customize the behavior instead.
/// </summary>
public sealed class LLamaLogger
public sealed class LLamaDefaultLogger : ILLamaLogger
{
private static readonly Lazy<LLamaLogger> _instance = new Lazy<LLamaLogger>(() => new LLamaLogger());
private static readonly Lazy<LLamaDefaultLogger> _instance = new Lazy<LLamaDefaultLogger>(() => new LLamaDefaultLogger());
private bool _toConsole = true;
private bool _toFile = false;
@@ -17,26 +37,26 @@ public sealed class LLamaLogger
private FileStream? _fileStream = null;
private StreamWriter? _fileWriter = null;
public static LLamaLogger Default => _instance.Value;
public static LLamaDefaultLogger Default => _instance.Value;
private LLamaLogger()
private LLamaDefaultLogger()
{
}
public LLamaLogger EnableConsole()
public LLamaDefaultLogger EnableConsole()
{
_toConsole = true;
return this;
}
public LLamaLogger DisableConsole()
public LLamaDefaultLogger DisableConsole()
{
_toConsole = false;
return this;
}
public LLamaLogger EnableFile(string filename, FileMode mode = FileMode.Append)
public LLamaDefaultLogger EnableFile(string filename, FileMode mode = FileMode.Append)
{
_fileStream = new FileStream(filename, mode, FileAccess.Write);
_fileWriter = new StreamWriter(_fileStream);
@@ -44,14 +64,14 @@ public sealed class LLamaLogger
return this;
}
public LLamaLogger DisableFile(string filename)
public LLamaDefaultLogger DisableFile(string filename)
{
if(_fileWriter is not null)
if (_fileWriter is not null)
{
_fileWriter.Close();
_fileWriter = null;
}
if(_fileStream is not null)
if (_fileStream is not null)
{
_fileStream.Close();
_fileStream = null;
@@ -60,6 +80,26 @@ public sealed class LLamaLogger
return this;
}
public void Log(string source, string message, LogLevel level)
{
if (level == LogLevel.Info)
{
Info(message);
}
else if (level == LogLevel.Debug)
{
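// Debug messages are not forwarded: the default logger defines no Debug handler, only Info/Warn/Error.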
}
else if (level == LogLevel.Warning)
{
Warn(message);
}
else if (level == LogLevel.Error)
{
Error(message);
}
}
public void Info(string message)
{
message = MessageFormat("info", message);
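
For reference, a minimal sketch of how a consumer could hook into the new interface, and how the renamed default logger is configured through its fluent methods. The log file paths and the FileOnlyLogger class are hypothetical, not part of this commit:

using System;
using LLama.Common;

// Configure the built-in singleton fluently (placeholder path):
LLamaDefaultLogger.Default.EnableConsole().EnableFile("llama.log");

// Or route everything through a custom implementation of the new interface;
// only the single Log method must be implemented.
class FileOnlyLogger : ILLamaLogger
{
    public void Log(string source, string message, ILLamaLogger.LogLevel level)
    {
        // "source" is a class or method name, e.g. nameof(LLamaModel).
        System.IO.File.AppendAllText("custom.log", $"[{level}] {source}: {message}{Environment.NewLine}");
    }
}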


@@ -6,30 +6,30 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
using llama_token = Int32;
public abstract class ChatExecutorBase: ILLamaExecutor
public abstract class ChatExecutorBase : ILLamaExecutor
{
protected readonly LLamaModel _model;
protected ILLamaLogger? _logger;
protected int _pastTokensCount; // n_past
protected int _consumedTokensCount; // n_consume
protected int _n_session_consumed;
protected int _n_matching_session_tokens;
protected string _pathSession;
protected string? _pathSession;
protected List<llama_token> _embeds = new(); // embd
protected List<llama_token> _embed_inps = new();
protected List<llama_token> _session_tokens = new();
protected FixedSizeQuene<llama_token> _last_n_tokens;
public LLamaModel Model => _model;
protected ChatExecutorBase(LLamaModel model)
protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
{
_model = model;
_logger = logger;
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
@@ -47,14 +47,49 @@ namespace LLama
}
if (File.Exists(filename))
{
_logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info);
llama_token[] session_tokens = new llama_token[_model.ContextSize];
ulong n_token_count_out = 0;
if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out))
{
_logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error);
throw new RuntimeError($"Failed to load session file {_pathSession}");
}
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
_logger?.Log("LLamaExecutor", $"Loaded a session with prompt size of {session_tokens.Length} tokens", ILLamaLogger.LogLevel.Info);
}
else
{
_logger?.Log("LLamaExecutor", $"Session file does not exist, will create", ILLamaLogger.LogLevel.Warning);
}
_n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (_n_matching_session_tokens >= _embed_inps.Count || id != _embed_inps[_n_matching_session_tokens])
{
break;
}
_n_matching_session_tokens++;
}
if (_n_matching_session_tokens >= _embed_inps.Count)
{
_logger?.Log("LLamaExecutor", $"Session file has exact match for prompt!", ILLamaLogger.LogLevel.Info);
}
else if (_n_matching_session_tokens < _embed_inps.Count / 2)
{
_logger?.Log("LLamaExecutor", $"session file has low similarity to prompt ({_n_matching_session_tokens}" +
$" / {_embed_inps.Count} tokens); will mostly be reevaluated", ILLamaLogger.LogLevel.Warning);
}
else
{
_logger?.Log("LLamaExecutor", $"Session file matches {_n_matching_session_tokens} / " +
$"{_embed_inps.Count} tokens of prompt", ILLamaLogger.LogLevel.Info);
}
}
return this;
}
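
The matching logic above is a plain longest-common-prefix count between the cached session tokens and the freshly tokenized prompt; only that prefix can be reused, and everything past it is re-evaluated. The same heuristic in isolation (a hypothetical helper, not part of the commit):

using System.Collections.Generic;

// Counts how many leading tokens of the saved session equal the new prompt.
// Tokens past this prefix must be fed through the model again.
static int CountMatchingSessionTokens(List<int> sessionTokens, List<int> promptTokens)
{
    int matched = 0;
    while (matched < sessionTokens.Count
        && matched < promptTokens.Count
        && sessionTokens[matched] == promptTokens[matched])
    {
        matched++;
    }
    return matched;
}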
@@ -169,7 +204,7 @@ namespace LLama
}
public virtual async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach(var result in Infer(text, inferenceParams, cancellationToken))
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
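
Since InferAsync simply yields from the synchronous Infer loop, callers consume it with await foreach. A usage sketch; the model path, the prompt, and the concrete InteractiveExecutor subclass are assumptions not shown in this diff:

using System;
using LLama;
using LLama.Common;

var model = new LLamaModel(new ModelParams("path/to/model.bin")); // placeholder path
var executor = new InteractiveExecutor(model); // assumed concrete subclass of ChatExecutorBase

await foreach (var piece in executor.InferAsync("Tell me a joke."))
{
    Console.Write(piece);
}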


@@ -1,7 +1,6 @@
using LLama.Exceptions;
using LLama.Native;
using LLama.OldVersion;
using LLama.Types;
using LLama.Extensions;
using System;
using System.Collections.Generic;
@@ -17,7 +16,7 @@ namespace LLama
public class LLamaModel: IDisposable
{
// TODO: expose more properties.
LLamaLogger _logger;
ILLamaLogger? _logger;
Encoding _encoding;
SafeLLamaContextHandle _ctx;
public int ContextSize { get; }
@@ -25,12 +24,12 @@ namespace LLama
public SafeLLamaContextHandle NativeHandle => _ctx;
public Encoding Encoding => _encoding;
public LLamaModel(ModelParams Params, string encoding = "UTF-8")
public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
{
_logger = LLamaLogger.Default;
_logger = logger;
this.Params = Params;
_encoding = Encoding.GetEncoding(encoding);
_logger.Info($"Initializing LLama model with params: {this.Params}");
_logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
_ctx = Utils.InitLLamaContextFromModelParams(this.Params);
ContextSize = NativeApi.llama_n_ctx(_ctx);
}
@@ -210,7 +209,7 @@ namespace LLama
if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
{
_logger.Error($"Failed to eval.");
_logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
}
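
Because every call site now uses the null-conditional `_logger?.Log(...)`, leaving the new constructor parameter at its default of null silently disables logging for that model, while any ILLamaLogger can be injected explicitly. A sketch (paths are placeholders):

using LLama;
using LLama.Common;

// No logger argument: _logger stays null, so all Log calls are skipped.
var quietModel = new LLamaModel(new ModelParams("path/to/model.bin"));

// Inject the renamed default logger, or any custom ILLamaLogger implementation.
var verboseModel = new LLamaModel(new ModelParams("path/to/model.bin"),
                                  logger: LLamaDefaultLogger.Default.EnableConsole());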


@@ -1,5 +1,4 @@
using LLama.Exceptions;
using LLama.Types;
using LLama.Extensions;
using LLama.Native;
using System;
@@ -8,6 +7,7 @@ using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using LLama.Common;
namespace LLama.OldVersion
{
@@ -184,12 +184,12 @@ namespace LLama.OldVersion
{
if (verbose)
{
LLamaLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
LLamaDefaultLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
}
if (!File.Exists(_path_session))
{
LLamaLogger.Default.Warn("Session file does not exist, will create.");
LLamaDefaultLogger.Default.Warn("Session file does not exist, will create.");
}
llama_token[] session_tokens = new llama_token[@params.n_ctx];
@@ -201,7 +201,7 @@ namespace LLama.OldVersion
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
if (verbose)
{
LLamaLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
LLamaDefaultLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
}
}
@@ -231,39 +231,39 @@ namespace LLama.OldVersion
if (_params.verbose_prompt)
{
LLamaLogger.Default.Info("\n");
LLamaLogger.Default.Info($"prompt: '{_params.prompt}'");
LLamaLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
LLamaDefaultLogger.Default.Info("\n");
LLamaDefaultLogger.Default.Info($"prompt: '{_params.prompt}'");
LLamaDefaultLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
for (int i = 0; i < _embed_inp.Count; i++)
{
LLamaLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
LLamaDefaultLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
}
if (_params.n_keep > 0)
{
LLamaLogger.Default.Info($"static prompt based on n_keep: '");
LLamaDefaultLogger.Default.Info($"static prompt based on n_keep: '");
for (int i = 0; i < _params.n_keep; i++)
{
LLamaLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
LLamaDefaultLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
}
LLamaLogger.Default.Info("\n");
LLamaDefaultLogger.Default.Info("\n");
}
LLamaLogger.Default.Info("\n");
LLamaDefaultLogger.Default.Info("\n");
}
if (_params.interactive && verbose)
{
LLamaLogger.Default.Info("interactive mode on.");
LLamaDefaultLogger.Default.Info("interactive mode on.");
}
if (verbose)
{
LLamaLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
LLamaDefaultLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
$"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
$"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
$" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
$" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
LLamaLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
LLamaDefaultLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
$"n_keep = {_params.n_keep}");
LLamaLogger.Default.Info("\n");
LLamaDefaultLogger.Default.Info("\n");
}
_last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
@@ -272,7 +272,7 @@ namespace LLama.OldVersion
{
if (verbose)
{
LLamaLogger.Default.Info("== Running in interactive mode. ==");
LLamaDefaultLogger.Default.Info("== Running in interactive mode. ==");
}
_is_interacting = _params.interactive_first;
}
@@ -316,16 +316,16 @@ namespace LLama.OldVersion
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
{
LLamaLogger.Default.Info("Session file has exact match for prompt!");
LLamaDefaultLogger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
LLamaLogger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
LLamaDefaultLogger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
{
LLamaLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
LLamaDefaultLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}
@@ -511,7 +511,7 @@ namespace LLama.OldVersion
{
if (_verbose)
{
LLamaLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
}
_is_interacting = false;
}
@@ -581,7 +581,7 @@ namespace LLama.OldVersion
var array = _embed.Skip(i).ToArray();
if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
{
LLamaLogger.Default.Error($"Failed to eval.");
LLamaDefaultLogger.Default.Error($"Failed to eval.");
throw new RuntimeError("Failed to eval.");
}
@@ -776,7 +776,7 @@ namespace LLama.OldVersion
}
else
{
LLamaLogger.Default.Info(" [end of text]");
LLamaDefaultLogger.Default.Info(" [end of text]");
}
}
@@ -790,7 +790,7 @@ namespace LLama.OldVersion
if (!string.IsNullOrEmpty(_path_session) && _params.prompt_cache_all)
{
LLamaLogger.Default.Info($"saving final output to session file {_path_session}");
LLamaDefaultLogger.Default.Info($"saving final output to session file {_path_session}");
var session_token_array = _session_tokens.ToArray();
NativeApi.llama_save_session_file(_ctx, _path_session, session_token_array, (ulong)session_token_array.Length);
}