First working version of r2ai-native, still far from clean

pancake 2023-12-20 21:13:51 +01:00
parent 5d87f1f174
commit 5405e69b65
2 changed files with 95 additions and 74 deletions

View File

@@ -35,6 +35,7 @@
#endif
#define R2AI_MODEL_PATH "/tmp/mistral-7b-v0.1.Q2_K.gguf"
// #define R2AI_MODEL_PATH "/tmp/Utopia-13B.q5_k_m.gguf"
// Abuse globals for now to make this messy code easier to refactor
static gpt_params params; // SHOULD NOT BE A GLOBAL
static std::string path_session;
@@ -81,29 +82,28 @@ static void sigint_handler(int signo) {
#endif
static void null_log(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
// LOG_TEE("%s", text);
(void) level;
(void) user_data;
// LOG_TEE ("%s", text);
}
static bool cxxreadline(std::string & line, bool multiline_input) {
#if defined(_WIN32)
std::wstring wline;
if (!std::getline(std::wcin, wline)) {
if (!std::getline (std::wcin, wline)) {
// Input stream is bad or EOF received
line.clear();
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
line.clear ();
GenerateConsoleCtrlEvent (CTRL_C_EVENT, 0);
return false;
}
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
line.resize(size_needed);
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
int size_needed = WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
line.resize (size_needed);
WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
#else
if (!std::getline(std::cin, line)) {
if (!std::getline (std::cin, line)) {
// Input stream is bad or EOF received
stdin_borken = true;
line.clear();
line.clear ();
return false;
}
#endif
@@ -129,7 +129,9 @@ bool main_r2ai_init(const char *model_path) {
params.input_suffix = "[/INST]";
params.antiprompt.push_back("[INST]");
params.antiprompt.push_back("</s>");
params.interactive = true;
params.interactive_first = true;
params.model = model_path; //
// params.prompt = "<s>"; // Act as a helpful assistant for radare2, your name is r2ai.";
/// params.prompt = "<s>[INST][/INST]"; // <s>[INST]Act as a helpful assistant for radare2, your name is r2ai[/INST]Sure!";
@@ -150,10 +152,10 @@ bool main_r2ai_init(const char *model_path) {
}
#endif
g_params = &params;
llama_log_set(null_log, nullptr);
llama_log_set (null_log, nullptr);
// log_set_target(log_filename_generator("main", "log"));
// log_dump_cmdline (argc, argv);
sparams = &g_params->sparams;
sparams = &g_params->sparams;
// TODO: Dump params ?
//printf ("Params perplexity: %s\n", LOG_TOSTR(params.perplexity));
@@ -188,9 +190,9 @@ bool main_r2ai_init(const char *model_path) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
LOG_TEE("%s: seed = %u\n", __func__, g_params.seed);
LOG_TEE ("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
LOG_TEE ("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
LOG_TEE ("%s: seed = %u\n", __func__, g_params.seed);
#endif
#if 0
@@ -251,7 +253,7 @@ bool main_r2ai_init(const char *model_path) {
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back (llama_token_bos(model));
printf ("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
@@ -297,7 +299,7 @@ bool main_r2ai_init(const char *model_path) {
#endif
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
llama_kv_cache_seq_rm (ctx, -1, n_matching_session_tokens, -1);
}
#if 0
LOGLN(
@@ -428,7 +430,6 @@ static void main_r2ai_fini(void) {
int main_r2ai_message(const char *message) {
bool is_antiprompt = false;
bool input_echo = false;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
@@ -447,10 +448,11 @@ int main_r2ai_message(const char *message) {
ctx_sampling = llama_sampling_init (*sparams);
bool breakloop = false;
bool eos_found = false;
while (!breakloop && ( (n_remain != 0 && !is_antiprompt) || params.interactive)) {
while (( (n_remain != 0 && !is_antiprompt) || params.interactive)) {
// predict
if (!embd.empty()) {
if (!embd.empty ()) {
// Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
// --prompt or --file which uses the same value.
int max_embd_size = n_ctx - 4;
@@ -458,12 +460,12 @@ int main_r2ai_message(const char *message) {
// Ensure the input doesn't exceed the context size by truncating embd if necessary.
if ((int) embd.size() > max_embd_size) {
const int skipped_tokens = (int) embd.size() - max_embd_size;
embd.resize(max_embd_size);
embd.resize (max_embd_size);
// console::set_display(console::error);
printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
printf ("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
// console::set_display(console::reset);
fflush(stdout);
fflush (stdout);
}
// infinite text generation via context swapping
@@ -477,13 +479,13 @@ int main_r2ai_message(const char *message) {
}
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
const int n_discard = n_left / 2;
printf ("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift (ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
@@ -495,21 +497,19 @@ int main_r2ai_message(const char *message) {
printf ("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
printf ("clear session path\n");
#endif
path_session.clear();
path_session.clear ();
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
for ( ; i < embd.size (); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}
n_past++;
n_session_consumed++;
if (n_session_consumed >= (int) session_tokens.size()) {
i++;
break;
@@ -524,7 +524,7 @@ int main_r2ai_message(const char *message) {
// embd is typically prepared beforehand to fit within a batch, but not always
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
llama_token *input_buf = NULL;
if (n_past_guidance < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
@@ -550,8 +550,8 @@ int main_r2ai_message(const char *message) {
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
if (llama_decode (ctx_guidance, llama_batch_get_one (input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE ("%s : failed to eval\n", __func__);
return 1;
}
@@ -565,26 +565,23 @@ int main_r2ai_message(const char *message) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
// printf ("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
// printf ("DECODE1 %d embd=%d\n", n_eval, embd[i]);
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
if (llama_decode (ctx, llama_batch_get_one (&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
n_past += n_eval;
// printf ("n_past = %d\n", n_past);
}
if (!embd.empty() && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
session_tokens.insert (session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
}
embd.clear();
embd_guidance.clear();
embd.clear ();
embd_guidance.clear ();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample (ctx_sampling, ctx, ctx_guidance);
@@ -593,7 +590,6 @@ int main_r2ai_message(const char *message) {
embd.push_back (id);
// echo this to console
input_echo = true;
// decrement remaining sampling budget
n_remain--;
} else {
@@ -606,24 +602,26 @@ int main_r2ai_message(const char *message) {
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept (ctx_sampling, ctx, embd_inp[n_consumed], false);
n_consumed++;
if ((int) embd.size() >= params.n_batch) {
break;
}
}
}
// display text
if (input_echo) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());
if (embd.size() > 1) {
input_tokens.push_back(id);
const std::string token_str = llama_token_to_piece (ctx, id);
if (token_str == "INST") {
// HACK
eos_found = true;
break;
}
printf ("%s", token_str.c_str ());
if (embd.size () > 1) {
input_tokens.push_back (id);
} else {
output_tokens.push_back(id);
output_tokens.push_back (id);
output_ss << token_str;
}
}
@@ -635,13 +633,12 @@ int main_r2ai_message(const char *message) {
console::set_display(console::reset);
}
#endif
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) {
if (!params.antiprompt.empty ()) {
const int n_prev = 32;
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
const std::string last_output = llama_sampling_prev_str (ctx_sampling, ctx, n_prev);
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
@@ -650,10 +647,10 @@ int main_r2ai_message(const char *message) {
for (std::string & antiprompt : params.antiprompt) {
size_t extra_padding = params.interactive ? 0 : 2;
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
? last_output.length () - static_cast<size_t>(antiprompt.length() + extra_padding)
: 0;
if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
if (last_output.find (antiprompt, search_start_pos) != std::string::npos) {
if (params.interactive) {
is_interacting = true;
}
@@ -664,14 +661,15 @@ int main_r2ai_message(const char *message) {
}
// deal with end of text token in interactive mode
if (llama_sampling_last (ctx_sampling) == llama_token_eos(model)) {
// printf ("found EOS token\n");
// printf ("%d %d\n", llama_sampling_last (ctx_sampling), llama_token_eos (model));
if (eos_found || llama_sampling_last (ctx_sampling) == llama_token_eos (model)) {
// printf ("found EOS token\n");
if (params.interactive) {
if (!params.antiprompt.empty()) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
embd_inp.insert (embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
@@ -680,16 +678,19 @@ int main_r2ai_message(const char *message) {
} else if (params.instruct || params.chatml) {
is_interacting = true;
}
if (breakloop) {
// printf ("BreakLoop\n");
break;
}
}
if (!breakloop)
if (message != nullptr || (n_past > 0 && is_interacting)) {
printf ("message is not nul\n");
if (params.instruct || params.chatml) {
printf ("\n> ");
}
if (params.input_prefix_bos) {
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back (llama_token_bos(model));
}
std::string buffer;
@@ -702,20 +703,24 @@ int main_r2ai_message(const char *message) {
// console::set_display(console::user_input);
std::string line;
if (breakloop) {
// printf ("BreakLoop\n");
// break;
}
bool another_line = true;
if (message == nullptr || message == NULL) {
if (message == nullptr) {
// initial prompt
do {
another_line = cxxreadline(line, params.multiline_input);
another_line = cxxreadline (line, params.multiline_input);
buffer += line;
} while (another_line);
if (stdin_borken) {
break;
}
} else {
printf ("kkkkk\n");
buffer += message;
// message = nullptr;
// message = nullptr;
breakloop = true;
}
@@ -727,13 +732,13 @@ int main_r2ai_message(const char *message) {
if (buffer.length() > 1) {
// printf ("POP BACK %d\n", buffer.length());
// append input suffix if any
#if 0
if (!params.input_suffix.empty()) {
// printf ("appending input suffix: '%s'\n", params.input_suffix.c_str());
// printf("%s", params.input_suffix.c_str());
}
// printf ("buffer: '%s'\n", buffer.c_str());
#endif
const size_t original_size = embd_inp.size();
// instruct mode: insert instruction prefix
@@ -747,7 +752,7 @@ int main_r2ai_message(const char *message) {
embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
}
if (params.escape) {
process_escapes(buffer);
process_escapes (buffer);
}
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
@@ -780,8 +785,9 @@ int main_r2ai_message(const char *message) {
}
if (n_past > 0) {
// printf ("npast > 0\n");
if (is_interacting) {
llama_sampling_reset(ctx_sampling);
llama_sampling_reset (ctx_sampling);
}
is_interacting = false;
}
@@ -803,11 +809,25 @@ int main_r2ai_message(const char *message) {
return 0;
}
extern bool main_r2ai_preinit(int argc, char **argv) {
const char *_argv[] = {"r2ai"};
if (argc < 1 || argv == nullptr) {
argc = 1;
argv = (char **)_argv;
}
if (!gpt_params_parse (argc, argv, params)) {
return false;
}
return true;
}
#if !R2AI
int main(int argc, char **argv) {
if (!gpt_params_parse (argc, argv, params)) {
if (!main_r2ai_preinit (argc, argv)) {
return 1;
}
if (!gpt_params_parse (argc, (char **)argv, params)) {
}
main_r2ai_init (R2AI_MODEL_PATH);
main_r2ai_message (NULL); // if null just get into the repl
main_r2ai_fini ();
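
For orientation, and not part of the diff: a minimal sketch of how the entry points this file exports are meant to be driven, based on the main() above and the extern declarations in the plugin file below. The wrapper name r2ai_native_example and the hardcoded model path (copied from the R2AI_MODEL_PATH define at the top of this file) are illustrative assumptions, not code from the commit.

#include <stdbool.h>
#include <stddef.h>

// Declarations as the r2 plugin imports them (see the second file in this commit)
extern bool main_r2ai_preinit(int argc, char **argv); // parse default gpt_params; argv may be NULL
extern bool main_r2ai_init(const char *model_path);   // load the GGUF model and prepare the prompt
extern int main_r2ai_message(const char *message);    // NULL enters the REPL, otherwise runs one turn

int r2ai_native_example(void) {
	// With (0, NULL) the preinit falls back to a fake "r2ai" argv internally
	if (!main_r2ai_preinit (0, NULL)) {
		return 1;
	}
	// Illustrative path, taken from the R2AI_MODEL_PATH define
	if (!main_r2ai_init ("/tmp/mistral-7b-v0.1.Q2_K.gguf")) {
		return 1;
	}
	// Single message; pass NULL instead to drop into the interactive loop
	return main_r2ai_message ("What is radare2?");
}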

View File

@@ -33,6 +33,9 @@
static void r2ai_parseflag(RCore *core, const char *input) {
switch (*input) {
case 'e':
r_core_cmd0 (core, "-e r2ai.");
break;
case 'v':
r_cons_printf ("r2ai-native-v0.1\n");
#if 0
@@ -46,11 +49,9 @@ static void r2ai_parseflag(RCore *core, const char *input) {
}
}
static void r2ai_init(RCore *core) {
}
extern int main_r2ai_message(const char *message);
extern bool main_r2ai_init(const char *model_path);
extern bool main_r2ai_preinit(int argc, char **argv);
static void r2ai_message(RCore *core, const char *input) {
const char *model_path = r_config_get (core->config, "r2ai.model");
@@ -60,7 +61,6 @@ static void r2ai_message(RCore *core, const char *input) {
char *s = r_str_newf ("<s>[INST]%s[/INST]%s</s>%s", prompt, prompt_reply, input);
main_r2ai_message (input);
free (s);
r2ai_init (core);
}
static int r_cmd_r2ai_init(void *user, const char *input) {
@@ -75,6 +75,7 @@ static int r_cmd_r2ai_init(void *user, const char *input) {
} else {
R_LOG_INFO ("Can't init");
}
main_r2ai_preinit (0, NULL);
return true;
}
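
A closing note, also not part of the commit: the string built in r2ai_message() and the input_suffix/antiprompt values set in main_r2ai_init() follow the Mistral instruct template. Below is a hedged sketch of that assembly, mirroring the r_str_newf format above; the helper name r2ai_wrap_turn is made up for illustration and, like the plugin's s, the result must be freed by the caller.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Mirrors the "<s>[INST]%s[/INST]%s</s>%s" format used in r2ai_message():
// a primed system turn (<s>[INST]prompt[/INST]reply</s>) followed by the raw
// user input; generation stops on the "[INST]" and "</s>" antiprompts that
// main_r2ai_init() registers.
static char *r2ai_wrap_turn(const char *prompt, const char *prompt_reply, const char *input) {
	size_t len = strlen (prompt) + strlen (prompt_reply) + strlen (input) + 32;
	char *s = malloc (len);
	if (s) {
		snprintf (s, len, "<s>[INST]%s[/INST]%s</s>%s", prompt, prompt_reply, input);
	}
	return s; // caller frees, like the r_str_newf() result in the plugin
}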