loadTransforms flag for LoadSession methods

This commit is contained in:
eublefar 2024-03-21 12:18:38 +01:00
parent 9440f153da
commit b8cd5b7ee5
1 changed files with 15 additions and 11 deletions

View File

@ -178,17 +178,17 @@ public class ChatSession
/// Load a session from a session state.
/// </summary>
/// <param name="state"></param>
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(SessionState state)
public void LoadSession(SessionState state, bool loadTransforms = true)
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.LoadState(state.ExecutorState);
}
else
{
throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
if (state.ExecutorState is not null)
{
statefulExecutor.LoadState(state.ExecutorState);
}
}
if (state.ContextState is null)
{
@ -199,18 +199,22 @@ public class ChatSession
Executor.Context.LoadState(state.ContextState);
}
History = new ChatHistory(state.History);
InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
OutputTransform = state.OutputTransform.Clone();
HistoryTransform = state.HistoryTransform.Clone();
if (loadTransforms)
{
InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
OutputTransform = state.OutputTransform.Clone();
HistoryTransform = state.HistoryTransform.Clone();
}
}
/// <summary>
/// Load a session from a directory.
/// </summary>
/// <param name="path"></param>
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(string path)
public void LoadSession(string path, bool loadTransforms = true)
{
var state = SessionState.Load(path);
// Handle non-polymorphic serialization of executor state
@ -219,7 +223,7 @@ public class ChatSession
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
}
LoadSession(state);
LoadSession(state, loadTransforms);
}
/// <summary>