diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs
index 27ede7aa..84e29862 100644
--- a/LLama.KernelMemory/BuilderExtensions.cs
+++ b/LLama.KernelMemory/BuilderExtensions.cs
@@ -8,6 +8,7 @@ using LLama;
using LLama.Common;
using Microsoft.KernelMemory.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
+using LLama.Native;
namespace LLamaSharp.KernelMemory
{
@@ -81,7 +82,9 @@ namespace LLamaSharp.KernelMemory
ContextSize = config?.ContextSize ?? 2048,
Seed = config?.Seed ?? 0,
GpuLayerCount = config?.GpuLayerCount ?? 20,
- EmbeddingMode = true
+ EmbeddingMode = true,
+ MainGpu = config?.MainGpu ?? 0,
+ SplitMode = config?.SplitMode ?? GPUSplitMode.None
};
var weights = LLamaWeights.LoadFromFile(parameters);
var context = weights.CreateContext(parameters);
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 0451d5bf..4a089fe4 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -27,7 +27,12 @@ namespace LLamaSharp.KernelMemory
public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
{
this._config = config;
- var @params = new ModelParams(_config.ModelPath) { EmbeddingMode = true };
+ var @params = new ModelParams(_config.ModelPath)
+ {
+ EmbeddingMode = true,
+ MainGpu = _config.MainGpu,
+ SplitMode = _config.SplitMode
+ };
_weights = LLamaWeights.LoadFromFile(@params);
_embedder = new LLamaEmbedder(_weights, @params);
_ownsWeights = true;
@@ -42,7 +47,12 @@ namespace LLamaSharp.KernelMemory
public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights weights)
{
this._config = config;
- var @params = new ModelParams(_config.ModelPath) { EmbeddingMode = true };
+ var @params = new ModelParams(_config.ModelPath)
+ {
+ EmbeddingMode = true,
+ MainGpu = _config.MainGpu,
+ SplitMode = _config.SplitMode
+ };
_weights = weights;
_embedder = new LLamaEmbedder(_weights, @params);
_ownsEmbedder = true;
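
With these constructors honoring MainGpu and SplitMode, the embedding generator can be pinned to a specific device. A minimal sketch of the intended usage, assuming the existing LLamaSharpConfig(modelPath) constructor and a hypothetical model path:

using LLama.Native;
using LLamaSharp.KernelMemory;

// Hypothetical model path, used only for illustration.
var config = new LLamaSharpConfig(@"models/llama-2-7b.Q4_K_M.gguf")
{
    // With SplitMode = None, MainGpu selects the single device that hosts the model.
    MainGpu = 1,
    SplitMode = GPUSplitMode.None
};

// This constructor loads the weights itself and applies the GPU settings above.
var embeddingGenerator = new LLamaSharpTextEmbeddingGenerator(config);
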
diff --git a/LLama.KernelMemory/LlamaSharpConfig.cs b/LLama.KernelMemory/LlamaSharpConfig.cs
index 9299759e..e5fc4bf1 100644
--- a/LLama.KernelMemory/LlamaSharpConfig.cs
+++ b/LLama.KernelMemory/LlamaSharpConfig.cs
@@ -1,4 +1,5 @@
using LLama.Common;
+using LLama.Native;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -41,6 +42,31 @@ namespace LLamaSharp.KernelMemory
/// </summary>
public int? GpuLayerCount { get; set; }
+ /// <summary>
+ /// main_gpu interpretation depends on split_mode:
+ /// <list type="bullet">
+ ///     <item>
+ ///         <term>None</term>
+ ///         <description>The GPU that is used for the entire model.</description>
+ ///     </item>
+ ///     <item>
+ ///         <term>Row</term>
+ ///         <description>The GPU that is used for small tensors and intermediate results.</description>
+ ///     </item>
+ ///     <item>
+ ///         <term>Layer</term>
+ ///         <description>Ignored.</description>
+ ///     </item>
+ /// </list>
+ /// </summary>
+ /// <value></value>
+ public int MainGpu { get; set; } = 0;
+
+ /// <summary>
+ /// How to split the model across multiple GPUs
+ /// </summary>
+ /// <value></value>
+ public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;
/// <summary>
/// Set the default inference parameters.
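
To make the interaction between the two new properties concrete, here is a hedged sketch of the two common configurations; the GPU indices and model path are illustrative only, and the LLamaSharpConfig(modelPath) constructor is assumed from the existing library:

using LLama.Native;
using LLamaSharp.KernelMemory;

// Single-GPU: the whole model lives on device 0 (MainGpu is the device index).
var singleGpu = new LLamaSharpConfig(@"models/model.gguf")   // hypothetical path
{
    GpuLayerCount = 35,
    MainGpu = 0,
    SplitMode = GPUSplitMode.None
};

// Multi-GPU: layers are distributed across all visible devices; MainGpu is ignored.
var layerSplit = new LLamaSharpConfig(@"models/model.gguf")  // hypothetical path
{
    GpuLayerCount = 35,
    SplitMode = GPUSplitMode.Layer
};
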
diff --git a/LLama.KernelMemory/LlamaSharpTextGenerator.cs b/LLama.KernelMemory/LlamaSharpTextGenerator.cs
index 7269152b..de6373ee 100644
--- a/LLama.KernelMemory/LlamaSharpTextGenerator.cs
+++ b/LLama.KernelMemory/LlamaSharpTextGenerator.cs
@@ -1,6 +1,7 @@
using LLama;
using LLama.Abstractions;
using LLama.Common;
+using LLama.Native;
using Microsoft.KernelMemory.AI;
using System;
using System.Collections.Generic;
@@ -34,7 +35,9 @@ namespace LLamaSharp.KernelMemory
{
ContextSize = config?.ContextSize ?? 2048,
Seed = config?.Seed ?? 0,
- GpuLayerCount = config?.GpuLayerCount ?? 20
+ GpuLayerCount = config?.GpuLayerCount ?? 20,
+ MainGpu = config?.MainGpu ?? 0,
+ SplitMode = config?.SplitMode ?? GPUSplitMode.None
};
_weights = LLamaWeights.LoadFromFile(parameters);
_context = _weights.CreateContext(parameters);
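
End to end, the new options flow from the config into the builder extensions and on into both the text generator and the embedding generator. A minimal sketch, assuming the WithLLamaSharpDefaults extension from BuilderExtensions.cs (method name taken from the library, not shown in the hunks above) and the standard Microsoft.KernelMemory builder; the exact KernelMemory API surface may differ by version, and the model path is hypothetical:

using LLama.Native;
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory;

var config = new LLamaSharpConfig(@"models/model.gguf")  // hypothetical path
{
    ContextSize = 2048,
    GpuLayerCount = 20,
    MainGpu = 0,
    SplitMode = GPUSplitMode.None
};

// Wires both LlamaSharpTextGenerator and LLamaSharpTextEmbeddingGenerator
// with the same context and GPU settings.
var memory = new KernelMemoryBuilder()
    .WithLLamaSharpDefaults(config)
    .Build<MemoryServerless>();

await memory.ImportTextAsync("LLamaSharp supports splitting a model across multiple GPUs.");
var answer = await memory.AskAsync("What does LLamaSharp support?");
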