refactor: use official api of quantization instead.

Yaohui Liu 2023-05-13 15:02:19 +08:00
parent 5e378f9a52
commit 6ffcb5306b
5 changed files with 56 additions and 24 deletions

File 1 of 5:

@@ -25,7 +25,7 @@ else if (choice == 3) // quantization
     q.Run(@"<Your src model file path>",
         @"<Your dst model file path>", "q4_1");
 }
-else if (choice == 4) // quantization
+else if (choice == 4) // get the embeddings only
 {
     GetEmbeddings em = new GetEmbeddings(@"<Your model file path>");
     em.Run("Hello, what is python?");

File 2 of 5:

@@ -13,9 +13,9 @@ namespace LLama.Examples
         }
-        public void Run(string srcFileName, string dstFilename, string ftype, int nthread = 0, bool printInfo = true)
+        public void Run(string srcFileName, string dstFilename, string ftype, int nthread = -1)
         {
-            if(Quantizer.Quantize(srcFileName, dstFilename, ftype, nthread, printInfo))
+            if(Quantizer.Quantize(srcFileName, dstFilename, ftype, nthread))
             {
                 Console.WriteLine("Quantization succeed!");
             }
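
For reference, a minimal sketch of driving the updated example (assuming the example class shown here is named Quantize, as the Program.cs snippet above suggests; the paths are placeholders):

    // Sketch only: omitting nthread uses the new default of -1, which
    // lets the native library pick the thread count itself.
    Quantize q = new Quantize();
    q.Run(@"<Your src model file path>", @"<Your dst model file path>", "q4_1");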

File 3 of 5:

@@ -7,8 +7,16 @@ namespace LLama.Native
 {
     internal partial class NativeApi
     {
+        /// <summary>
+        /// Returns 0 on success
+        /// </summary>
+        /// <param name="fname_inp"></param>
+        /// <param name="fname_out"></param>
+        /// <param name="ftype"></param>
+        /// <param name="nthread">how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given</param>
+        /// <remarks>not great API - very likely to change</remarks>
+        /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName)]
-        public static extern bool ggml_custom_quantize(string src_filename, string dst_filename,
-            string ftype_str, int nthread, bool print_info);
+        public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
     }
 }
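
With the custom ggml_custom_quantize binding gone, the library now P/Invokes llama.cpp's own exported llama_model_quantize. A minimal sketch of a direct call (only usable from inside the library, since NativeApi is internal; the file names are placeholders):

    // The LLamaFtype argument marshals as its underlying int, matching the
    // native enum llama_ftype parameter; a return value of 0 means success.
    int result = NativeApi.llama_model_quantize(
        "ggml-model-f16.bin",               // fname_inp (placeholder)
        "ggml-model-q4_1.bin",              // fname_out (placeholder)
        LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1,
        -1);                                // <= 0: use std::thread::hardware_concurrency()
    if (result != 0)
    {
        Console.WriteLine("Quantization failed.");
    }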

File 4 of 5:

@@ -42,18 +42,6 @@ namespace LLama.Native
         [DllImport(libraryName)]
         public static extern void llama_free(IntPtr ctx);
-        /// <summary>
-        /// Returns 0 on success
-        /// </summary>
-        /// <param name="fname_inp"></param>
-        /// <param name="fname_out"></param>
-        /// <param name="ftype"></param>
-        /// <param name="nthread">how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given</param>
-        /// <remarks>not great API - very likely to change</remarks>
-        /// <returns>Returns 0 on success</returns>
-        [DllImport(libraryName)]
-        public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for

File 5 of 5:

@@ -8,20 +8,37 @@ namespace LLama
 {
     public class Quantizer
     {
-        public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = 0, bool printInfo = true)
-        {
-            return Quantize(srcFileName, dstFilename, FtypeToString(ftype), nthread, printInfo);
-        }
-        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = 0, bool printInfo = true)
+        /// <summary>
+        /// Quantize the model.
+        /// </summary>
+        /// <param name="srcFileName">The model file to be quantized.</param>
+        /// <param name="dstFilename">The path to save the quantized model.</param>
+        /// <param name="ftype">The type of quantization.</param>
+        /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
+        /// <returns>Whether the quantization is successful.</returns>
+        /// <exception cref="ArgumentException"></exception>
+        public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1)
         {
             if (!ValidateFtype(ftype))
             {
                 throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " +
                     $"to perform quantization.");
             }
-            return NativeApi.ggml_custom_quantize(srcFileName, dstFilename, ftype, nthread, printInfo);
+            return NativeApi.llama_model_quantize(srcFileName, dstFilename, ftype, nthread) == 0;
         }
+        /// <summary>
+        /// Quantize the model.
+        /// </summary>
+        /// <param name="srcFileName">The model file to be quantized.</param>
+        /// <param name="dstFilename">The path to save the quantized model.</param>
+        /// <param name="ftype">The type of quantization.</param>
+        /// <param name="nthread">Thread to be used during the quantization. By default it's the physical core number.</param>
+        /// <returns>Whether the quantization is successful.</returns>
+        /// <exception cref="ArgumentException"></exception>
+        public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1)
+        {
+            return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread);
+        }
         private static bool ValidateFtype(string ftype)
@@ -29,6 +46,12 @@ namespace LLama
             return new string[] { "q4_0", "q4_1", "q4_2", "q5_0", "q5_1", "q8_0" }.Contains(ftype);
         }
+        private static bool ValidateFtype(LLamaFtype ftype)
+        {
+            return ftype is LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_2
+                or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1 or LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0;
+        }
         private static string FtypeToString(LLamaFtype ftype)
         {
             return ftype switch
@@ -43,5 +66,18 @@ namespace LLama
                 $"to perform quantization.")
             };
         }
+        private static LLamaFtype StringToFtype(string str)
+        {
+            return str switch
+            {
+                "q4_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0,
+                "q4_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1,
+                "q4_2" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_2,
+                "q5_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0,
+                "q5_1" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1,
+                "q8_0" => LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0,
+            };
+        }
     }
 }
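
Taken together, the public surface after this refactor is the two Quantizer.Quantize overloads. A usage sketch (model paths are placeholders):

    // Enum overload: checked by ValidateFtype(LLamaFtype), then handed
    // straight to NativeApi.llama_model_quantize.
    bool ok = Quantizer.Quantize(@"<Your src model file path>",
        @"<Your dst model file path>", LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1);

    // String overload: "q4_1" is mapped to the enum via StringToFtype first.
    bool ok2 = Quantizer.Quantize(@"<Your src model file path>",
        @"<Your dst model file path>", "q4_1");

One design note: StringToFtype has no discard arm, so an unrecognized string throws a SwitchExpressionException at runtime before any validation runs; as far as this diff shows, ValidateFtype(string) is no longer reachable from the public path.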