Address PR change requests

This commit is contained in:
Mihai 2023-08-30 09:24:08 +03:00
parent 7f31276bdf
commit 8b4ec6d973
1 changed files with 28 additions and 27 deletions

View File

@ -1,6 +1,7 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama.Grammar
{
@ -14,12 +15,12 @@ namespace LLama.Grammar
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
public Tuple<uint, ReadOnlyMemory<char>> DecodeUTF8(ReadOnlyMemory<char> src)
public (uint, ReadOnlyMemory<byte>) DecodeUTF8(ReadOnlyMemory<byte> src)
{
ReadOnlySpan<char> span = src.Span;
ReadOnlySpan<byte> span = src.Span;
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
byte firstByte = (byte)span[0];
byte firstByte = span[0];
byte highbits = (byte)(firstByte >> 4);
int len = lookup[highbits];
byte mask = (byte)((1 << (8 - len)) - 1);
@ -30,15 +31,15 @@ namespace LLama.Grammar
for (; pos < end && pos < src.Length; pos++)
{
value = (uint)((value << 6) + ((byte)span[pos] & 0x3F));
value = (uint)((value << 6) + (span[pos] & 0x3F));
}
ReadOnlyMemory<char> nextSpan = src.Slice(pos);
ReadOnlyMemory<byte> nextSpan = src.Slice(pos);
return new Tuple<uint, ReadOnlyMemory<char>>(value, nextSpan);
return (value, nextSpan);
}
public uint GetSymbolId(ParseState state, ReadOnlySpan<char> src, int len)
public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
{
uint nextId = (uint)state.SymbolIds.Count;
string key = src.Slice(0, len).ToString();
@ -72,23 +73,23 @@ namespace LLama.Grammar
state.Rules[(int)ruleId] = rule;
}
public bool IsWordChar(char c)
public bool IsWordChar(byte c)
{
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}
public Tuple<uint, ReadOnlyMemory<char>> ParseHex(ReadOnlyMemory<char> src, int size)
public (uint, ReadOnlyMemory<byte>) ParseHex(ReadOnlyMemory<byte> src, int size)
{
int pos = 0;
int end = size;
uint value = 0;
ReadOnlySpan<char> srcSpan = src.Span;
ReadOnlySpan<byte> srcSpan = src.Span;
for (; pos < end && pos < src.Length; pos++)
{
value <<= 4;
char c = srcSpan[pos];
byte c = srcSpan[pos];
if ('a' <= c && c <= 'f')
{
value += (uint)(c - 'a' + 10);
@ -109,13 +110,13 @@ namespace LLama.Grammar
if (pos != end)
{
throw new InvalidOperationException($"Expecting {size} hex chars at {src.ToString()}");
throw new InvalidOperationException($"Expecting {size} hex chars at {src}");
}
return new Tuple<uint, ReadOnlyMemory<char>>(value, src.Slice(pos));
return (value, src.Slice(pos));
}
public ReadOnlySpan<char> ParseSpace(ReadOnlySpan<char> src, bool newlineOk)
public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{
int pos = 0;
while (pos < src.Length &&
@ -137,7 +138,7 @@ namespace LLama.Grammar
return src.Slice(pos);
}
public ReadOnlySpan<char> ParseName(ReadOnlySpan<char> src)
public ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{
int pos = 0;
while (pos < src.Length && IsWordChar(src[pos]))
@ -151,13 +152,13 @@ namespace LLama.Grammar
return src.Slice(pos);
}
public Tuple<uint, ReadOnlyMemory<char>> ParseChar(ReadOnlyMemory<char> src)
public (uint, ReadOnlyMemory<byte>) ParseChar(ReadOnlyMemory<byte> src)
{
ReadOnlySpan<char> span = src.Span;
ReadOnlySpan<byte> span = src.Span;
if (span[0] == '\\')
{
switch (span[1])
switch ((char)span[1])
{
case 'x':
return ParseHex(src.Slice(2), 2);
@ -166,16 +167,16 @@ namespace LLama.Grammar
case 'U':
return ParseHex(src.Slice(2), 8);
case 't':
return new Tuple<uint, ReadOnlyMemory<char>>('\t', src.Slice(2));
return ('\t', src.Slice(2));
case 'r':
return new Tuple<uint, ReadOnlyMemory<char>>('\r', src.Slice(2));
return ('\r', src.Slice(2));
case 'n':
return new Tuple<uint, ReadOnlyMemory<char>>('\n', src.Slice(2));
return ('\n', src.Slice(2));
case '\\':
case '"':
case '[':
case ']':
return new Tuple<uint, ReadOnlyMemory<char>>(span[1], src.Slice(2));
return (span[1], src.Slice(2));
default:
throw new Exception("Unknown escape at " + src.ToString());
}
@ -188,9 +189,9 @@ namespace LLama.Grammar
throw new Exception("Unexpected end of input");
}
public ReadOnlySpan<char> ParseSequence(
public ReadOnlySpan<byte> ParseSequence(
ref ParseState state,
ReadOnlyMemory<char> src,
ReadOnlyMemory<byte> src,
string ruleName,
List<LLamaGrammarElement> outElements,
bool isNested)
@ -262,7 +263,7 @@ namespace LLama.Grammar
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
if (pos[0] != ')')
{
throw new Exception($"Expecting ')' at {new string(pos.ToArray())}");
throw new Exception($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
}
pos = ParseSpace(pos.Slice(1), isNested);
}
@ -270,7 +271,7 @@ namespace LLama.Grammar
{
if (lastSymStart == outElements.Count)
{
throw new Exception($"Expecting preceding item to */+/? at {new string(pos.ToArray())}");
throw new Exception($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
}
// apply transformation to previous symbol (lastSymStart to end) according to
@ -320,7 +321,7 @@ namespace LLama.Grammar
return pos;
}
public ReadOnlySpan<char> ParseAlternates(ParseState state, ReadOnlySpan<char> pos, string ruleName, uint subRuleId, bool v)
public ReadOnlySpan<byte> ParseAlternates(ParseState state, ReadOnlySpan<byte> pos, string ruleName, uint subRuleId, bool v)
{
throw new NotImplementedException();
}