diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs
index 4a7bf587..19a4dd51 100644
--- a/LLama.Examples/Program.cs
+++ b/LLama.Examples/Program.cs
@@ -25,7 +25,7 @@ else if (choice == 3) // quantization
 {
     q.Run(@"", @"", "q4_1");
 }
-else if (choice == 4) // quantization
+else if (choice == 4) // get the embeddings only
 {
     GetEmbeddings em = new GetEmbeddings(@"");
     em.Run("Hello, what is python?");
diff --git a/LLama.Examples/Quantize.cs b/LLama.Examples/Quantize.cs
index cf1f0419..03987cea 100644
--- a/LLama.Examples/Quantize.cs
+++ b/LLama.Examples/Quantize.cs
@@ -13,9 +13,9 @@ namespace LLama.Examples
 
         }
 
-        public void Run(string srcFileName, string dstFilename, string ftype, int nthread = 0, bool printInfo = true)
+        public void Run(string srcFileName, string dstFilename, string ftype, int nthread = -1)
         {
-            if(Quantizer.Quantize(srcFileName, dstFilename, ftype, nthread, printInfo))
+            if(Quantizer.Quantize(srcFileName, dstFilename, ftype, nthread))
             {
                 Console.WriteLine("Quantization succeed!");
             }
diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs
index f91fc9d8..75e58096 100644
--- a/LLama/Native/NativeApi.Quantize.cs
+++ b/LLama/Native/NativeApi.Quantize.cs
@@ -7,8 +7,16 @@ namespace LLama.Native
 {
     internal partial class NativeApi
     {
+        /// <summary>
+        /// Returns 0 on success
+        /// </summary>
+        /// <param name="fname_inp"></param>
+        /// <param name="fname_out"></param>
+        /// <param name="ftype"></param>
+        /// <param name="nthread">how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given</param>
+        /// <remarks>not great API - very likely to change</remarks>
+        /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName)]
-        public static extern bool ggml_custom_quantize(string src_filename, string dst_filename,
-            string ftype_str, int nthread, bool print_info);
+        public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
     }
 }
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 21b62a27..94d7504a 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -42,18 +42,6 @@ namespace LLama.Native
         [DllImport(libraryName)]
         public static extern void llama_free(IntPtr ctx);
 
-        /// <summary>
-        /// Returns 0 on success
-        /// </summary>
-        /// <param name="fname_inp"></param>
-        /// <param name="fname_out"></param>
-        /// <param name="ftype"></param>
-        /// <param name="nthread">how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given</param>
-        /// <remarks>not great API - very likely to change</remarks>
-        /// <returns>Returns 0 on success</returns>
-        [DllImport(libraryName)]
-        public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
-
         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for
diff --git a/LLama/Quantizer.cs b/LLama/Quantizer.cs
index 42ba68b6..671aacf0 100644
--- a/LLama/Quantizer.cs
+++ b/LLama/Quantizer.cs
@@ -8,20 +8,37 @@ namespace LLama
 {
     public class Quantizer
     {
-        public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = 0, bool printInfo = true)
-        {
-            return Quantize(srcFileName, dstFilename, FtypeToString(ftype), nthread, printInfo);
-        }
-
-        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = 0, bool printInfo = true)
+        /// <summary>
+        /// Quantize the model.
+        /// </summary>
+        /// <param name="srcFileName">The model file to be quantized.</param>
+        /// <param name="dstFilename">The path to save the quantized model.</param>
+        /// <param name="ftype">The type of quantization.</param>
+        /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
+        /// <returns>Whether the quantization is successful.</returns>
+        /// <exception cref="ArgumentException"></exception>
+        public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1)
         {
             if (!ValidateFtype(ftype))
             {
                 throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " +
                     $"to perform quantization.");
             }
+            return NativeApi.llama_model_quantize(srcFileName, dstFilename, ftype, nthread) == 0;
+        }
 
-            return NativeApi.ggml_custom_quantize(srcFileName, dstFilename, ftype, nthread, printInfo);
+        /// <summary>
+        /// Quantize the model.
+        /// </summary>
+        /// <param name="srcFileName">The model file to be quantized.</param>
+        /// <param name="dstFilename">The path to save the quantized model.</param>
+        /// <param name="ftype">The type of quantization.</param>
+        /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
+        /// <returns>Whether the quantization is successful.</returns>
+        /// <exception cref="ArgumentException"></exception>
+        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1)
+        {
+            return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread);
         }
 
         private static bool ValidateFtype(string ftype)
@@ -29,6 +46,12 @@ namespace LLama
             return new string[] { "q4_0", "q4_1", "q4_2", "q5_0", "q5_1", "q8_0" }.Contains(ftype);
         }
+        private static bool ValidateFtype(LLamaFtype ftype)
+        {
+            return ftype is LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_2
+                or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0;
+        }
+
         private static string FtypeToString(LLamaFtype ftype)
         {
             return ftype switch
             {
@@ -43,5 +66,18 @@ namespace LLama
                     $"to perform quantization.")
             };
         }
+
+        private static LLamaFtype StringToFtype(string str)
+        {
+            return str switch
+            {
+                "q4_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0,
+                "q4_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1,
+                "q4_2" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_2,
+                "q5_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0,
+                "q5_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1,
+                "q8_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0,
+            };
+        }
     }
 }
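Note for reviewers: a minimal usage sketch of the reworked API, assuming a ggml f16 model on disk. The file paths below are placeholders, not files from this repo. The string "q4_1" is mapped to LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 by the new StringToFtype helper, and per the doc comment on llama_model_quantize, any nthread <= 0 falls back to the native default, std::thread::hardware_concurrency():

    using System;
    using LLama;

    // Placeholder paths; point these at a real source model and output location.
    bool ok = Quantizer.Quantize(@"ggml-model-f16.bin", @"ggml-model-q4_1.bin", "q4_1", nthread: -1);
    Console.WriteLine(ok ? "Quantization succeeded." : "Quantization failed.");

One thing worth flagging: the switch expression in StringToFtype has no default arm, so an unrecognized string (say "q2_0") surfaces as a runtime SwitchExpressionException rather than the ArgumentException the LLamaFtype overload throws, and the compiler will warn that the switch expression is not exhaustive. A fallback arm along these lines would keep the two overloads failing consistently; the exact message is only a suggestion:

    _ => throw new ArgumentException($"The type {str} is not a valid type to perform quantization.")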