Address PR change requests

This commit is contained in:
Mihai 2023-08-30 09:24:08 +03:00
parent 7f31276bdf
commit 8b4ec6d973
1 changed files with 28 additions and 27 deletions

View File

@ -1,6 +1,7 @@
using LLama.Native; using LLama.Native;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text;
namespace LLama.Grammar namespace LLama.Grammar
{ {
@ -14,12 +15,12 @@ namespace LLama.Grammar
{ {
// NOTE: assumes valid utf8 (but checks for overrun) // NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp // copied from llama.cpp
public Tuple<uint, ReadOnlyMemory<char>> DecodeUTF8(ReadOnlyMemory<char> src) public (uint, ReadOnlyMemory<byte>) DecodeUTF8(ReadOnlyMemory<byte> src)
{ {
ReadOnlySpan<char> span = src.Span; ReadOnlySpan<byte> span = src.Span;
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
byte firstByte = (byte)span[0]; byte firstByte = span[0];
byte highbits = (byte)(firstByte >> 4); byte highbits = (byte)(firstByte >> 4);
int len = lookup[highbits]; int len = lookup[highbits];
byte mask = (byte)((1 << (8 - len)) - 1); byte mask = (byte)((1 << (8 - len)) - 1);
@ -30,15 +31,15 @@ namespace LLama.Grammar
for (; pos < end && pos < src.Length; pos++) for (; pos < end && pos < src.Length; pos++)
{ {
value = (uint)((value << 6) + ((byte)span[pos] & 0x3F)); value = (uint)((value << 6) + (span[pos] & 0x3F));
} }
ReadOnlyMemory<char> nextSpan = src.Slice(pos); ReadOnlyMemory<byte> nextSpan = src.Slice(pos);
return new Tuple<uint, ReadOnlyMemory<char>>(value, nextSpan); return (value, nextSpan);
} }
public uint GetSymbolId(ParseState state, ReadOnlySpan<char> src, int len) public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
{ {
uint nextId = (uint)state.SymbolIds.Count; uint nextId = (uint)state.SymbolIds.Count;
string key = src.Slice(0, len).ToString(); string key = src.Slice(0, len).ToString();
@ -72,23 +73,23 @@ namespace LLama.Grammar
state.Rules[(int)ruleId] = rule; state.Rules[(int)ruleId] = rule;
} }
public bool IsWordChar(char c) public bool IsWordChar(byte c)
{ {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
} }
public Tuple<uint, ReadOnlyMemory<char>> ParseHex(ReadOnlyMemory<char> src, int size) public (uint, ReadOnlyMemory<byte>) ParseHex(ReadOnlyMemory<byte> src, int size)
{ {
int pos = 0; int pos = 0;
int end = size; int end = size;
uint value = 0; uint value = 0;
ReadOnlySpan<char> srcSpan = src.Span; ReadOnlySpan<byte> srcSpan = src.Span;
for (; pos < end && pos < src.Length; pos++) for (; pos < end && pos < src.Length; pos++)
{ {
value <<= 4; value <<= 4;
char c = srcSpan[pos]; byte c = srcSpan[pos];
if ('a' <= c && c <= 'f') if ('a' <= c && c <= 'f')
{ {
value += (uint)(c - 'a' + 10); value += (uint)(c - 'a' + 10);
@ -109,13 +110,13 @@ namespace LLama.Grammar
if (pos != end) if (pos != end)
{ {
throw new InvalidOperationException($"Expecting {size} hex chars at {src.ToString()}"); throw new InvalidOperationException($"Expecting {size} hex chars at {src}");
} }
return new Tuple<uint, ReadOnlyMemory<char>>(value, src.Slice(pos)); return (value, src.Slice(pos));
} }
public ReadOnlySpan<char> ParseSpace(ReadOnlySpan<char> src, bool newlineOk) public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{ {
int pos = 0; int pos = 0;
while (pos < src.Length && while (pos < src.Length &&
@ -137,7 +138,7 @@ namespace LLama.Grammar
return src.Slice(pos); return src.Slice(pos);
} }
public ReadOnlySpan<char> ParseName(ReadOnlySpan<char> src) public ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{ {
int pos = 0; int pos = 0;
while (pos < src.Length && IsWordChar(src[pos])) while (pos < src.Length && IsWordChar(src[pos]))
@ -151,13 +152,13 @@ namespace LLama.Grammar
return src.Slice(pos); return src.Slice(pos);
} }
public Tuple<uint, ReadOnlyMemory<char>> ParseChar(ReadOnlyMemory<char> src) public (uint, ReadOnlyMemory<byte>) ParseChar(ReadOnlyMemory<byte> src)
{ {
ReadOnlySpan<char> span = src.Span; ReadOnlySpan<byte> span = src.Span;
if (span[0] == '\\') if (span[0] == '\\')
{ {
switch (span[1]) switch ((char)span[1])
{ {
case 'x': case 'x':
return ParseHex(src.Slice(2), 2); return ParseHex(src.Slice(2), 2);
@ -166,16 +167,16 @@ namespace LLama.Grammar
case 'U': case 'U':
return ParseHex(src.Slice(2), 8); return ParseHex(src.Slice(2), 8);
case 't': case 't':
return new Tuple<uint, ReadOnlyMemory<char>>('\t', src.Slice(2)); return ('\t', src.Slice(2));
case 'r': case 'r':
return new Tuple<uint, ReadOnlyMemory<char>>('\r', src.Slice(2)); return ('\r', src.Slice(2));
case 'n': case 'n':
return new Tuple<uint, ReadOnlyMemory<char>>('\n', src.Slice(2)); return ('\n', src.Slice(2));
case '\\': case '\\':
case '"': case '"':
case '[': case '[':
case ']': case ']':
return new Tuple<uint, ReadOnlyMemory<char>>(span[1], src.Slice(2)); return (span[1], src.Slice(2));
default: default:
throw new Exception("Unknown escape at " + src.ToString()); throw new Exception("Unknown escape at " + src.ToString());
} }
@ -188,9 +189,9 @@ namespace LLama.Grammar
throw new Exception("Unexpected end of input"); throw new Exception("Unexpected end of input");
} }
public ReadOnlySpan<char> ParseSequence( public ReadOnlySpan<byte> ParseSequence(
ref ParseState state, ref ParseState state,
ReadOnlyMemory<char> src, ReadOnlyMemory<byte> src,
string ruleName, string ruleName,
List<LLamaGrammarElement> outElements, List<LLamaGrammarElement> outElements,
bool isNested) bool isNested)
@ -262,7 +263,7 @@ namespace LLama.Grammar
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
if (pos[0] != ')') if (pos[0] != ')')
{ {
throw new Exception($"Expecting ')' at {new string(pos.ToArray())}"); throw new Exception($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
} }
pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);
} }
@ -270,7 +271,7 @@ namespace LLama.Grammar
{ {
if (lastSymStart == outElements.Count) if (lastSymStart == outElements.Count)
{ {
throw new Exception($"Expecting preceding item to */+/? at {new string(pos.ToArray())}"); throw new Exception($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
} }
// apply transformation to previous symbol (lastSymStart to end) according to // apply transformation to previous symbol (lastSymStart to end) according to
@ -320,7 +321,7 @@ namespace LLama.Grammar
return pos; return pos;
} }
public ReadOnlySpan<char> ParseAlternates(ParseState state, ReadOnlySpan<char> pos, string ruleName, uint subRuleId, bool v) public ReadOnlySpan<byte> ParseAlternates(ParseState state, ReadOnlySpan<byte> pos, string ruleName, uint subRuleId, bool v)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }