refactor: allow customized logger.

Author: Yaohui Liu
Date:   2023-06-12 03:11:44 +08:00
Commit: b567399b65 (parent: 3bf74ec9b9)
Signature: GPG Key ID E86D01E1809BD23E (no known key found for this signature in database)

4 changed files with 123 additions and 49 deletions

File 1 of 4 (logger, namespace LLama.Common):

@@ -1,15 +1,35 @@
 using System;
 using System.Diagnostics;
 using System.IO;
+using static LLama.Common.ILLamaLogger;

-namespace LLama.Types;
+namespace LLama.Common;

+public interface ILLamaLogger
+{
+    public enum LogLevel
+    {
+        Info,
+        Debug,
+        Warning,
+        Error
+    }
+
+    /// <summary>
+    /// Write the log in a customized way.
+    /// </summary>
+    /// <param name="source">The source of the log. It may be a method name or a class name.</param>
+    /// <param name="message">The message.</param>
+    /// <param name="level">The log level.</param>
+    void Log(string source, string message, LogLevel level);
+}
+
 /// <summary>
-/// The logger of LLamaSharp. On default it write to console. User methods of `LLamaLogger.Default` to change the behavior.
+/// The default logger of LLamaSharp. By default it writes to the console. Use methods of `LLamaDefaultLogger.Default` to change the behavior.
+/// It's recommended to implement `ILLamaLogger` to customize the behavior instead.
 /// </summary>
-public sealed class LLamaLogger
+public sealed class LLamaDefaultLogger : ILLamaLogger
 {
-    private static readonly Lazy<LLamaLogger> _instance = new Lazy<LLamaLogger>(() => new LLamaLogger());
+    private static readonly Lazy<LLamaDefaultLogger> _instance = new Lazy<LLamaDefaultLogger>(() => new LLamaDefaultLogger());

     private bool _toConsole = true;
     private bool _toFile = false;
@@ -17,26 +37,26 @@ public sealed class LLamaLogger
     private FileStream? _fileStream = null;
     private StreamWriter _fileWriter = null;

-    public static LLamaLogger Default => _instance.Value;
+    public static LLamaDefaultLogger Default => _instance.Value;

-    private LLamaLogger()
+    private LLamaDefaultLogger()
     {
     }

-    public LLamaLogger EnableConsole()
+    public LLamaDefaultLogger EnableConsole()
     {
         _toConsole = true;
         return this;
     }

-    public LLamaLogger DisableConsole()
+    public LLamaDefaultLogger DisableConsole()
     {
         _toConsole = false;
         return this;
     }

-    public LLamaLogger EnableFile(string filename, FileMode mode = FileMode.Append)
+    public LLamaDefaultLogger EnableFile(string filename, FileMode mode = FileMode.Append)
     {
         _fileStream = new FileStream(filename, mode, FileAccess.Write);
         _fileWriter = new StreamWriter(_fileStream);
@@ -44,14 +64,14 @@ public sealed class LLamaLogger
         return this;
     }

-    public LLamaLogger DisableFile(string filename)
+    public LLamaDefaultLogger DisableFile(string filename)
     {
-        if(_fileWriter is not null)
+        if (_fileWriter is not null)
         {
             _fileWriter.Close();
             _fileWriter = null;
         }
-        if(_fileStream is not null)
+        if (_fileStream is not null)
         {
             _fileStream.Close();
             _fileStream = null;
@@ -60,6 +80,26 @@ public sealed class LLamaLogger
         return this;
     }

+    public void Log(string source, string message, LogLevel level)
+    {
+        if (level == LogLevel.Info)
+        {
+            Info(message);
+        }
+        else if (level == LogLevel.Debug)
+        {
+            // Debug messages are silently dropped by the default logger.
+        }
+        else if (level == LogLevel.Warning)
+        {
+            Warn(message);
+        }
+        else if (level == LogLevel.Error)
+        {
+            Error(message);
+        }
+    }
+
     public void Info(string message)
     {
         message = MessageFormat("info", message);
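
For illustration, a minimal consumer-side implementation of the new `ILLamaLogger` interface might look like the sketch below. The `MyLogger` class name and the log-file path are hypothetical, not part of this commit; only the interface and its nested `LogLevel` enum come from the diff above.

    using System;
    using System.IO;
    using LLama.Common;
    using static LLama.Common.ILLamaLogger;

    // Hypothetical custom logger: timestamps every message, echoes it to
    // stderr, and additionally appends warnings and errors to a file.
    public sealed class MyLogger : ILLamaLogger
    {
        private readonly string _path;

        public MyLogger(string path)
        {
            _path = path; // e.g. "llama.log"; arbitrary choice for this sketch
        }

        public void Log(string source, string message, LogLevel level)
        {
            var line = $"[{DateTime.Now:HH:mm:ss}][{level}][{source}]: {message}";
            Console.Error.WriteLine(line);
            if (level == LogLevel.Warning || level == LogLevel.Error)
            {
                File.AppendAllText(_path, line + Environment.NewLine);
            }
        }
    }

The built-in logger remains usable and chainable, since every `Enable*`/`Disable*` method returns `this`, e.g. `LLamaDefaultLogger.Default.EnableConsole().EnableFile("llama.log")`.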

File 2 of 4 (ChatExecutorBase, namespace LLama):

@@ -6,30 +6,30 @@ using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Runtime.CompilerServices;
-using System.Text;
 using System.Text.Json.Serialization;
 using System.Threading;
-using System.Threading.Tasks;

 namespace LLama
 {
     using llama_token = Int32;

-    public abstract class ChatExecutorBase: ILLamaExecutor
+    public abstract class ChatExecutorBase : ILLamaExecutor
     {
         protected readonly LLamaModel _model;
+        protected ILLamaLogger? _logger;
         protected int _pastTokensCount; // n_past
         protected int _consumedTokensCount; // n_consume
         protected int _n_session_consumed;
         protected int _n_matching_session_tokens;
-        protected string _pathSession;
+        protected string? _pathSession;
         protected List<llama_token> _embeds = new(); // embd
         protected List<llama_token> _embed_inps = new();
         protected List<llama_token> _session_tokens = new();
         protected FixedSizeQuene<llama_token> _last_n_tokens;
         public LLamaModel Model => _model;

-        protected ChatExecutorBase(LLamaModel model)
+        protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
         {
             _model = model;
+            _logger = logger;
             _pastTokensCount = 0;
             _consumedTokensCount = 0;
             _n_session_consumed = 0;
@@ -47,14 +47,49 @@ namespace LLama
             }
             if (File.Exists(filename))
             {
+                _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info);
                 llama_token[] session_tokens = new llama_token[_model.ContextSize];
                 ulong n_token_count_out = 0;
                 if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out))
                 {
+                    _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError($"Failed to load session file {_pathSession}");
                 }
                 _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
+                _logger?.Log("LLamaExecutor", $"Loaded a session with prompt size of {session_tokens.Length} tokens", ILLamaLogger.LogLevel.Info);
             }
+            else
+            {
+                _logger?.Log("LLamaExecutor", $"Session file does not exist, will create", ILLamaLogger.LogLevel.Warning);
+            }
+
+            _n_matching_session_tokens = 0;
+            if (_session_tokens.Count > 0)
+            {
+                foreach (var id in _session_tokens)
+                {
+                    if (_n_matching_session_tokens >= _embed_inps.Count || id != _embed_inps[_n_matching_session_tokens])
+                    {
+                        break;
+                    }
+                    _n_matching_session_tokens++;
+                }
+                if (_n_matching_session_tokens >= _embed_inps.Count)
+                {
+                    _logger?.Log("LLamaExecutor", $"Session file has exact match for prompt!", ILLamaLogger.LogLevel.Info);
+                }
+                else if (_n_matching_session_tokens < _embed_inps.Count / 2)
+                {
+                    _logger?.Log("LLamaExecutor", $"session file has low similarity to prompt ({_n_matching_session_tokens}" +
+                        $" / {_embed_inps.Count} tokens); will mostly be reevaluated", ILLamaLogger.LogLevel.Warning);
+                }
+                else
+                {
+                    _logger?.Log("LLamaExecutor", $"Session file matches {_n_matching_session_tokens} / " +
+                        $"{_embed_inps.Count} tokens of prompt", ILLamaLogger.LogLevel.Info);
+                }
+            }
             return this;
         }
@@ -169,7 +204,7 @@ namespace LLama
         }

         public virtual async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            foreach(var result in Infer(text, inferenceParams, cancellationToken))
+            foreach (var result in Infer(text, inferenceParams, cancellationToken))
             {
                 yield return result;
             }
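
Aside: the session-reuse messages above are driven by a simple prefix match between the cached session tokens and the freshly tokenized prompt. As a standalone sketch (the class and method names here are illustrative, not from the commit):

    using System.Collections.Generic;

    static class SessionMatch
    {
        // Count how many leading tokens of the cached session agree with the
        // prompt; this mirrors the foreach loop in the executor code above.
        public static int CountMatchingPrefix(IReadOnlyList<int> session, IReadOnlyList<int> prompt)
        {
            int n = 0;
            while (n < session.Count && n < prompt.Count && session[n] == prompt[n])
            {
                n++;
            }
            return n;
        }
    }

The executor then classifies the count exactly as the log messages describe: a count equal to the prompt length is an exact match, a count below half the prompt length is logged as low similarity (most tokens get re-evaluated), and anything in between is partial reuse of the cache.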

File 3 of 4 (LLamaModel, namespace LLama):

@@ -1,7 +1,6 @@
 using LLama.Exceptions;
 using LLama.Native;
 using LLama.OldVersion;
-using LLama.Types;
 using LLama.Extensions;
 using System;
 using System.Collections.Generic;
@@ -17,7 +16,7 @@ namespace LLama
     public class LLamaModel: IDisposable
     {
         // TODO: expose more properties.
-        LLamaLogger _logger;
+        ILLamaLogger? _logger;
         Encoding _encoding;
         SafeLLamaContextHandle _ctx;
         public int ContextSize { get; }
@@ -25,12 +24,12 @@ namespace LLama
         public SafeLLamaContextHandle NativeHandle => _ctx;
         public Encoding Encoding => _encoding;

-        public LLamaModel(ModelParams Params, string encoding = "UTF-8")
+        public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
         {
-            _logger = LLamaLogger.Default;
+            _logger = logger;
             this.Params = Params;
             _encoding = Encoding.GetEncoding(encoding);
-            _logger.Info($"Initializing LLama model with params: {this.Params}");
+            _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
             _ctx = Utils.InitLLamaContextFromModelParams(this.Params);
             ContextSize = NativeApi.llama_n_ctx(_ctx);
         }
@@ -210,7 +209,7 @@ namespace LLama
             if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
             {
-                _logger.Error($"Failed to eval.");
+                _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                 throw new RuntimeError("Failed to eval.");
             }
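
With this constructor change, a logger is injected per model rather than hard-wired to the singleton. A usage sketch follows; the `ModelParams` constructor taking a model path is an assumption based on how it is used elsewhere in the repo, and `MyLogger` is the hypothetical implementation sketched earlier:

    using LLama;
    using LLama.Common;

    // Inject a custom logger; "path/to/model.bin" is a placeholder path.
    var model = new LLamaModel(new ModelParams("path/to/model.bin"), "UTF-8", new MyLogger("llama.log"));

    // Or pass nothing: _logger stays null and every `_logger?.Log(...)`
    // call above becomes a no-op, so logging is strictly opt-in.
    var quietModel = new LLamaModel(new ModelParams("path/to/model.bin"));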

File 4 of 4 (namespace LLama.OldVersion):

@@ -1,5 +1,4 @@
 using LLama.Exceptions;
-using LLama.Types;
 using LLama.Extensions;
 using LLama.Native;
 using System;
@@ -8,6 +7,7 @@ using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Text;
+using LLama.Common;

 namespace LLama.OldVersion
 {
@@ -184,12 +184,12 @@ namespace LLama.OldVersion
             {
                 if (verbose)
                 {
-                    LLamaLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
+                    LLamaDefaultLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
                 }
                 if (!File.Exists(_path_session))
                 {
-                    LLamaLogger.Default.Warn("Session file does not exist, will create.");
+                    LLamaDefaultLogger.Default.Warn("Session file does not exist, will create.");
                 }

                 llama_token[] session_tokens = new llama_token[@params.n_ctx];
@@ -201,7 +201,7 @@ namespace LLama.OldVersion
                 _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
                 if (verbose)
                 {
-                    LLamaLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
+                    LLamaDefaultLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
                 }
             }
@@ -231,39 +231,39 @@ namespace LLama.OldVersion
             if (_params.verbose_prompt)
             {
-                LLamaLogger.Default.Info("\n");
-                LLamaLogger.Default.Info($"prompt: '{_params.prompt}'");
-                LLamaLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
+                LLamaDefaultLogger.Default.Info("\n");
+                LLamaDefaultLogger.Default.Info($"prompt: '{_params.prompt}'");
+                LLamaDefaultLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
                 for (int i = 0; i < _embed_inp.Count; i++)
                 {
-                    LLamaLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
+                    LLamaDefaultLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
                 }
                 if (_params.n_keep > 0)
                 {
-                    LLamaLogger.Default.Info($"static prompt based on n_keep: '");
+                    LLamaDefaultLogger.Default.Info($"static prompt based on n_keep: '");
                     for (int i = 0; i < _params.n_keep; i++)
                     {
-                        LLamaLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
+                        LLamaDefaultLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
                     }
-                    LLamaLogger.Default.Info("\n");
+                    LLamaDefaultLogger.Default.Info("\n");
                 }
-                LLamaLogger.Default.Info("\n");
+                LLamaDefaultLogger.Default.Info("\n");
             }
             if (_params.interactive && verbose)
             {
-                LLamaLogger.Default.Info("interactive mode on.");
+                LLamaDefaultLogger.Default.Info("interactive mode on.");
             }
             if (verbose)
             {
-                LLamaLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
+                LLamaDefaultLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
                     $"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
                     $"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
                     $" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
                     $" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
-                LLamaLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
+                LLamaDefaultLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
                     $"n_keep = {_params.n_keep}");
-                LLamaLogger.Default.Info("\n");
+                LLamaDefaultLogger.Default.Info("\n");
             }

             _last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
@@ -272,7 +272,7 @@ namespace LLama.OldVersion
             {
                 if (verbose)
                 {
-                    LLamaLogger.Default.Info("== Running in interactive mode. ==");
+                    LLamaDefaultLogger.Default.Info("== Running in interactive mode. ==");
                 }
                 _is_interacting = _params.interactive_first;
             }
@@ -316,16 +316,16 @@ namespace LLama.OldVersion
             }
             if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
             {
-                LLamaLogger.Default.Info("Session file has exact match for prompt!");
+                LLamaDefaultLogger.Default.Info("Session file has exact match for prompt!");
             }
             else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
             {
-                LLamaLogger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
+                LLamaDefaultLogger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
                     $"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
             }
             else
             {
-                LLamaLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
+                LLamaDefaultLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
                     $"tokens of prompt.");
             }
@@ -511,7 +511,7 @@ namespace LLama.OldVersion
             {
                 if (_verbose)
                 {
-                    LLamaLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
+                    LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
                 }
                 _is_interacting = false;
             }
@@ -581,7 +581,7 @@ namespace LLama.OldVersion
                 var array = _embed.Skip(i).ToArray();
                 if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
                 {
-                    LLamaLogger.Default.Error($"Failed to eval.");
+                    LLamaDefaultLogger.Default.Error($"Failed to eval.");
                     throw new RuntimeError("Failed to eval.");
                 }
@@ -776,7 +776,7 @@ namespace LLama.OldVersion
             }
             else
             {
-                LLamaLogger.Default.Info(" [end of text]");
+                LLamaDefaultLogger.Default.Info(" [end of text]");
             }
@@ -790,7 +790,7 @@ namespace LLama.OldVersion
             if (!string.IsNullOrEmpty(_path_session) && _params.prompt_cache_all)
             {
-                LLamaLogger.Default.Info($"saving final output to session file {_path_session}");
+                LLamaDefaultLogger.Default.Info($"saving final output to session file {_path_session}");
                 var session_token_array = _session_tokens.ToArray();
                 NativeApi.llama_save_session_file(_ctx, _path_session, session_token_array, (ulong)session_token_array.Length);
             }
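
The OldVersion changes are a mechanical rename from `LLamaLogger` to `LLamaDefaultLogger`; callers only need the class name swapped. The fluent configuration carries over unchanged, since each `Enable*`/`Disable*` method returns `this` (the file name below is an arbitrary example):

    // Before: LLamaLogger.Default.EnableConsole().Info("hello");
    // After:
    LLamaDefaultLogger.Default
        .EnableConsole()
        .EnableFile("llama.log")
        .Info("hello from the renamed singleton");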