diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs index fca6e4f3..1a990ffb 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammar/GrammarParser.cs @@ -1,6 +1,7 @@ using LLama.Native; using System; using System.Collections.Generic; +using System.Text; namespace LLama.Grammar { @@ -14,12 +15,12 @@ namespace LLama.Grammar { // NOTE: assumes valid utf8 (but checks for overrun) // copied from llama.cpp - public Tuple> DecodeUTF8(ReadOnlyMemory src) + public (uint, ReadOnlyMemory) DecodeUTF8(ReadOnlyMemory src) { - ReadOnlySpan span = src.Span; + ReadOnlySpan span = src.Span; int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - byte firstByte = (byte)span[0]; + byte firstByte = span[0]; byte highbits = (byte)(firstByte >> 4); int len = lookup[highbits]; byte mask = (byte)((1 << (8 - len)) - 1); @@ -30,15 +31,15 @@ namespace LLama.Grammar for (; pos < end && pos < src.Length; pos++) { - value = (uint)((value << 6) + ((byte)span[pos] & 0x3F)); + value = (uint)((value << 6) + (span[pos] & 0x3F)); } - ReadOnlyMemory nextSpan = src.Slice(pos); + ReadOnlyMemory nextSpan = src.Slice(pos); - return new Tuple>(value, nextSpan); + return (value, nextSpan); } - public uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) + public uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) { uint nextId = (uint)state.SymbolIds.Count; string key = src.Slice(0, len).ToString(); @@ -72,23 +73,23 @@ namespace LLama.Grammar state.Rules[(int)ruleId] = rule; } - public bool IsWordChar(char c) + public bool IsWordChar(byte c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } - public Tuple> ParseHex(ReadOnlyMemory src, int size) + public (uint, ReadOnlyMemory) ParseHex(ReadOnlyMemory src, int size) { int pos = 0; int end = size; uint value = 0; - ReadOnlySpan srcSpan = src.Span; + ReadOnlySpan srcSpan = src.Span; for (; pos < end && pos < src.Length; pos++) { value <<= 4; - char c = srcSpan[pos]; + byte c = srcSpan[pos]; if ('a' <= c && c <= 'f') { value += (uint)(c - 'a' + 10); @@ -109,13 +110,13 @@ namespace LLama.Grammar if (pos != end) { - throw new InvalidOperationException($"Expecting {size} hex chars at {src.ToString()}"); + throw new InvalidOperationException($"Expecting {size} hex chars at {src}"); } - return new Tuple>(value, src.Slice(pos)); + return (value, src.Slice(pos)); } - public ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + public ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) { int pos = 0; while (pos < src.Length && @@ -137,7 +138,7 @@ namespace LLama.Grammar return src.Slice(pos); } - public ReadOnlySpan ParseName(ReadOnlySpan src) + public ReadOnlySpan ParseName(ReadOnlySpan src) { int pos = 0; while (pos < src.Length && IsWordChar(src[pos])) @@ -151,13 +152,13 @@ namespace LLama.Grammar return src.Slice(pos); } - public Tuple> ParseChar(ReadOnlyMemory src) + public (uint, ReadOnlyMemory) ParseChar(ReadOnlyMemory src) { - ReadOnlySpan span = src.Span; + ReadOnlySpan span = src.Span; if (span[0] == '\\') { - switch (span[1]) + switch ((char)span[1]) { case 'x': return ParseHex(src.Slice(2), 2); @@ -166,16 +167,16 @@ namespace LLama.Grammar case 'U': return ParseHex(src.Slice(2), 8); case 't': - return new Tuple>('\t', src.Slice(2)); + return ('\t', src.Slice(2)); case 'r': - return new Tuple>('\r', src.Slice(2)); + return ('\r', src.Slice(2)); case 'n': - return new Tuple>('\n', src.Slice(2)); + return ('\n', src.Slice(2)); case '\\': case '"': case '[': case ']': - return new Tuple>(span[1], src.Slice(2)); + return (span[1], src.Slice(2)); default: throw new Exception("Unknown escape at " + src.ToString()); } @@ -188,9 +189,9 @@ namespace LLama.Grammar throw new Exception("Unexpected end of input"); } - public ReadOnlySpan ParseSequence( + public ReadOnlySpan ParseSequence( ref ParseState state, - ReadOnlyMemory src, + ReadOnlyMemory src, string ruleName, List outElements, bool isNested) @@ -262,7 +263,7 @@ namespace LLama.Grammar outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); if (pos[0] != ')') { - throw new Exception($"Expecting ')' at {new string(pos.ToArray())}"); + throw new Exception($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); } pos = ParseSpace(pos.Slice(1), isNested); } @@ -270,7 +271,7 @@ namespace LLama.Grammar { if (lastSymStart == outElements.Count) { - throw new Exception($"Expecting preceding item to */+/? at {new string(pos.ToArray())}"); + throw new Exception($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}"); } // apply transformation to previous symbol (lastSymStart to end) according to @@ -320,7 +321,7 @@ namespace LLama.Grammar return pos; } - public ReadOnlySpan ParseAlternates(ParseState state, ReadOnlySpan pos, string ruleName, uint subRuleId, bool v) + public ReadOnlySpan ParseAlternates(ParseState state, ReadOnlySpan pos, string ruleName, uint subRuleId, bool v) { throw new NotImplementedException(); }