First working version of r2ai-native, still far from clean

parent 5d87f1f174
commit 5405e69b65
@@ -35,6 +35,7 @@
 #endif
 
 #define R2AI_MODEL_PATH "/tmp/mistral-7b-v0.1.Q2_K.gguf"
+// #define R2AI_MODEL_PATH "/tmp/Utopia-13B.q5_k_m.gguf"
 // Abuse globals to make things easier to refactor this shitty code
 static gpt_params params; // SHOULDNT BE A GLOBAL
 static std::string path_session;
@@ -81,29 +82,28 @@ static void sigint_handler(int signo) {
 #endif
 
 static void null_log(ggml_log_level level, const char * text, void * user_data) {
-	(void) level;
-	(void) user_data;
-	// LOG_TEE("%s", text);
+	(void) level;
+	(void) user_data;
+	// LOG_TEE ("%s", text);
 }
 
 static bool cxxreadline(std::string & line, bool multiline_input) {
 #if defined(_WIN32)
 	std::wstring wline;
-	if (!std::getline(std::wcin, wline)) {
+	if (!std::getline (std::wcin, wline)) {
 		// Input stream is bad or EOF received
-		line.clear();
-		GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
+		line.clear ();
+		GenerateConsoleCtrlEvent (CTRL_C_EVENT, 0);
 		return false;
 	}
 
-	int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
-	line.resize(size_needed);
-	WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
+	int size_needed = WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
+	line.resize (size_needed);
+	WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
 #else
-	if (!std::getline(std::cin, line)) {
+	if (!std::getline (std::cin, line)) {
 		// Input stream is bad or EOF received
-		line.clear();
+		stdin_borken = true;
+		line.clear ();
 		return false;
 	}
 #endif
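For reference, the Windows branch above is the standard two-call WideCharToMultiByte pattern: the first call measures the required UTF-8 length, the second fills the buffer. A minimal sketch of the same conversion as a standalone helper (utf16_to_utf8 is a hypothetical name, not from the commit):

#include <string>
#include <windows.h>

static std::string utf16_to_utf8(const std::wstring &wline) {
	if (wline.empty ()) {
		return std::string ();
	}
	// First call: measure the required UTF-8 buffer size.
	int size_needed = WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
	std::string line;
	line.resize (size_needed);
	// Second call: convert into the resized string.
	WideCharToMultiByte (CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
	return line;
}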
@@ -129,7 +129,9 @@ bool main_r2ai_init(const char *model_path) {
 	params.input_suffix = "[/INST]";
 
 	params.antiprompt.push_back("[INST]");
+	params.antiprompt.push_back("</s>");
 	params.interactive = true;
 	params.interactive_first = true;
 	params.model = model_path; //
 	// params.prompt = "<s>"; // Act as a helpful assistant for radare2, your name is r2ai.";
+	/// params.prompt = "<s>[INST][/INST]"; // <s>[INST]Act as a helpful assistant for radare2, your name is r2ai[/INST]Sure!";
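These parameters wire the session to the Mistral-instruct chat template: user input is closed with the "[/INST]" suffix and generation stops when one of the antiprompts is emitted. As an illustration only (the message string is hypothetical), a single turn has this shape:

std::string user_msg = "what does this function do?"; // hypothetical input
std::string turn = "<s>[INST]" + user_msg + "[/INST]";
// the reply then streams until "</s>" or a new "[INST]" appears,
// matching the antiprompt strings registered above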
@@ -150,10 +152,10 @@ bool main_r2ai_init(const char *model_path) {
 	}
 #endif
 	g_params = &params;
-	llama_log_set(null_log, nullptr);
+	llama_log_set (null_log, nullptr);
 	// log_set_target(log_filename_generator("main", "log"));
 	// log_dump_cmdline (argc, argv);
-	sparams = &g_params->sparams;
+	sparams = &g_params->sparams;
 	// TODO: Dump params ?
 	//printf ("Params perplexity: %s\n", LOG_TOSTR(params.perplexity));
 
@@ -188,9 +190,9 @@ bool main_r2ai_init(const char *model_path) {
 		LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
 	}
 
-	LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-	LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
-	LOG_TEE("%s: seed = %u\n", __func__, g_params.seed);
+	LOG_TEE ("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+	LOG_TEE ("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+	LOG_TEE ("%s: seed = %u\n", __func__, g_params.seed);
 #endif
 
 #if 0
@@ -251,7 +253,7 @@ bool main_r2ai_init(const char *model_path) {
 
 	// Should not run without any tokens
 	if (embd_inp.empty()) {
-		embd_inp.push_back(llama_token_bos(model));
+		embd_inp.push_back (llama_token_bos(model));
 		printf ("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 	}
 
@@ -297,7 +299,7 @@ bool main_r2ai_init(const char *model_path) {
 #endif
 
 		// remove any "future" tokens that we might have inherited from the previous session
-		llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+		llama_kv_cache_seq_rm (ctx, -1, n_matching_session_tokens, -1);
 	}
 #if 0
 	LOGLN(
@@ -428,7 +430,6 @@ static void main_r2ai_fini(void) {
 int main_r2ai_message(const char *message) {
 	bool is_antiprompt = false;
 	bool input_echo = false;
 
 	int n_past = 0;
 	int n_remain = params.n_predict;
 	int n_consumed = 0;
@@ -447,10 +448,11 @@ int main_r2ai_message(const char *message) {
 
 	ctx_sampling = llama_sampling_init (*sparams);
 	bool breakloop = false;
+	bool eos_found = false;
 
-	while (!breakloop && ( (n_remain != 0 && !is_antiprompt) || params.interactive)) {
+	while (( (n_remain != 0 && !is_antiprompt) || params.interactive)) {
 		// predict
-		if (!embd.empty()) {
+		if (!embd.empty ()) {
 			// Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
 			// --prompt or --file which uses the same value.
 			int max_embd_size = n_ctx - 4;
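With this change the `!breakloop &&` guard moves out of the while condition and into explicit checks inside the body (see the `if (breakloop)` hunks further down), so a programmatically injected message can finish its reply before the loop exits. In outline, condensed from the hunks below:

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
	// 1. decode pending tokens in embd, swapping context when it fills up
	// 2. sample the next token, or consume queued prompt tokens
	// 3. echo tokens, watching for the "INST" hack and the antiprompts
	// 4. on EOS/antiprompt, read the next line (repl) or stop (one-shot)
	if (breakloop) {
		break; // one-shot message answered; leave the repl
	}
}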
@@ -458,12 +460,12 @@ int main_r2ai_message(const char *message) {
 			// Ensure the input doesn't exceed the context size by truncating embd if necessary.
 			if ((int) embd.size() > max_embd_size) {
 				const int skipped_tokens = (int) embd.size() - max_embd_size;
-				embd.resize(max_embd_size);
+				embd.resize (max_embd_size);
 
 				// console::set_display(console::error);
-				printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+				printf ("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
 				// console::set_display(console::reset);
-				fflush(stdout);
+				fflush (stdout);
 			}
 
 			// infinite text generation via context swapping
@@ -477,13 +479,13 @@ int main_r2ai_message(const char *message) {
 				}
 
 				const int n_left = n_past - params.n_keep - 1;
-				const int n_discard = n_left/2;
+				const int n_discard = n_left / 2;
 
 				printf ("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
 					n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-				llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
-				llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+				llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+				llama_kv_cache_seq_shift (ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
 				n_past -= n_discard;
 
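A worked example of the swap arithmetic above, with hypothetical values:

// hypothetical numbers plugged into the lines above
const int n_past = 512, n_keep = 32;      // context (n_ctx = 512) is full
const int n_left = n_past - n_keep - 1;   // 512 - 32 - 1 = 479 tokens up for discussion
const int n_discard = n_left / 2;         // drop the older half: 239
// llama_kv_cache_seq_rm (...) removes cached positions [33, 272)
// llama_kv_cache_seq_shift (...) slides [272, 512) left by 239, onto [33, 273)
// n_past becomes 512 - 239 = 273, with the first n_keep prompt tokens intact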
@@ -495,21 +497,19 @@ int main_r2ai_message(const char *message) {
 				printf ("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 				printf ("clear session path\n");
 #endif
-				path_session.clear();
+				path_session.clear ();
 			}
 
 			// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
 			if (n_session_consumed < (int) session_tokens.size()) {
 				size_t i = 0;
-				for ( ; i < embd.size(); i++) {
+				for ( ; i < embd.size (); i++) {
 					if (embd[i] != session_tokens[n_session_consumed]) {
 						session_tokens.resize(n_session_consumed);
 						break;
 					}
 
 					n_past++;
 					n_session_consumed++;
 
 					if (n_session_consumed >= (int) session_tokens.size()) {
 						i++;
 						break;
@@ -524,7 +524,7 @@ int main_r2ai_message(const char *message) {
 			// embd is typically prepared beforehand to fit within a batch, but not always
 			if (ctx_guidance) {
 				int input_size = 0;
-				llama_token * input_buf = NULL;
+				llama_token *input_buf = NULL;
 				if (n_past_guidance < (int) guidance_inp.size()) {
 					// Guidance context should have the same data with these modifications:
 					//
@@ -550,8 +550,8 @@ int main_r2ai_message(const char *message) {
 
 				for (int i = 0; i < input_size; i += params.n_batch) {
 					int n_eval = std::min(input_size - i, params.n_batch);
-					if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
-						LOG_TEE("%s : failed to eval\n", __func__);
+					if (llama_decode (ctx_guidance, llama_batch_get_one (input_buf + i, n_eval, n_past_guidance, 0))) {
+						LOG_TEE ("%s : failed to eval\n", __func__);
 						return 1;
 					}
 
|
@ -565,26 +565,23 @@ int main_r2ai_message(const char *message) {
|
|||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
|
||||
// printf ("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
||||
// printf ("DECODE1 %d embd=%d\n", n_eval, embd[i]);
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
|
||||
if (llama_decode (ctx, llama_batch_get_one (&embd[i], n_eval, n_past, 0))) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_past += n_eval;
|
||||
// printf ("n_past = %d\n", n_past);
|
||||
}
|
||||
if (!embd.empty() && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
session_tokens.insert (session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
embd_guidance.clear();
|
||||
embd.clear ();
|
||||
embd_guidance.clear ();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
const llama_token id = llama_sampling_sample (ctx_sampling, ctx, ctx_guidance);
|
||||
|
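The eval loop above feeds pending tokens to the model at most params.n_batch at a time; each successful llama_decode call extends the KV cache, tracked by n_past. Condensed into one loop (same calls as the hunk, error handling shortened):

for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
	int n_eval = std::min((int) embd.size() - i, params.n_batch);
	if (llama_decode (ctx, llama_batch_get_one (&embd[i], n_eval, n_past, 0))) {
		return 1; // eval failed
	}
	n_past += n_eval; // the KV cache now covers n_past positions
}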
@@ -593,7 +590,6 @@ int main_r2ai_message(const char *message) {
 			embd.push_back (id);
 			// echo this to console
 			input_echo = true;
 
 			// decrement remaining sampling budget
 			n_remain--;
 		} else {
@@ -606,24 +602,26 @@ int main_r2ai_message(const char *message) {
 				// push the prompt in the sampling context in order to apply repetition penalties later
 				// for the prompt, we don't apply grammar rules
 				llama_sampling_accept (ctx_sampling, ctx, embd_inp[n_consumed], false);
 
 				n_consumed++;
 				if ((int) embd.size() >= params.n_batch) {
 					break;
 				}
 			}
 		}
 
 		// display text
 		if (input_echo) {
 			for (auto id : embd) {
-				const std::string token_str = llama_token_to_piece(ctx, id);
-				printf("%s", token_str.c_str());
-
-				if (embd.size() > 1) {
-					input_tokens.push_back(id);
+				const std::string token_str = llama_token_to_piece (ctx, id);
+				if (token_str == "INST") {
+					// HACK
+					eos_found = true;
+					break;
+				}
+				printf ("%s", token_str.c_str ());
+				if (embd.size () > 1) {
+					input_tokens.push_back (id);
 				} else {
-					output_tokens.push_back(id);
+					output_tokens.push_back (id);
 					output_ss << token_str;
 				}
 			}
@@ -635,13 +633,12 @@ int main_r2ai_message(const char *message) {
 			console::set_display(console::reset);
 		}
 #endif
 
 		// if not currently processing queued inputs;
 		if ((int) embd_inp.size() <= n_consumed) {
 			// check for reverse prompt in the last n_prev tokens
-			if (!params.antiprompt.empty()) {
+			if (!params.antiprompt.empty ()) {
 				const int n_prev = 32;
-				const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
+				const std::string last_output = llama_sampling_prev_str (ctx_sampling, ctx, n_prev);
 
 				is_antiprompt = false;
 				// Check if each of the reverse prompts appears at the end of the output.
@@ -650,10 +647,10 @@ int main_r2ai_message(const char *message) {
 				for (std::string & antiprompt : params.antiprompt) {
 					size_t extra_padding = params.interactive ? 0 : 2;
 					size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
-						? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+						? last_output.length () - static_cast<size_t>(antiprompt.length() + extra_padding)
 						: 0;
 
-					if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+					if (last_output.find (antiprompt, search_start_pos) != std::string::npos) {
 						if (params.interactive) {
 							is_interacting = true;
 						}
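The reverse-prompt check above only searches the tail of the last n_prev decoded tokens. Pulled out as a self-contained helper it reads like this (a sketch; ends_with_antiprompt is a hypothetical name):

#include <string>

static bool ends_with_antiprompt(const std::string &last_output,
                                 const std::string &antiprompt, bool interactive) {
	// in non-interactive mode allow 2 extra characters of slack at the end
	size_t extra_padding = interactive ? 0 : 2;
	size_t search_start_pos = last_output.length () > antiprompt.length () + extra_padding
		? last_output.length () - (antiprompt.length () + extra_padding)
		: 0;
	return last_output.find (antiprompt, search_start_pos) != std::string::npos;
}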
@@ -664,14 +661,15 @@ int main_r2ai_message(const char *message) {
 			}
 
 			// deal with end of text token in interactive mode
-			if (llama_sampling_last (ctx_sampling) == llama_token_eos(model)) {
-				// printf ("found EOS token\n");
+			// printf ("%d %d\n", llama_sampling_last (ctx_sampling), llama_token_eos (model));
+			if (eos_found || llama_sampling_last (ctx_sampling) == llama_token_eos (model)) {
+				// printf ("found EOS token\n");
 
 				if (params.interactive) {
 					if (!params.antiprompt.empty()) {
 						// tokenize and inject first reverse prompt
 						const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
-						embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+						embd_inp.insert (embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
 						is_antiprompt = true;
 					}
 
@@ -680,16 +678,19 @@ int main_r2ai_message(const char *message) {
 			} else if (params.instruct || params.chatml) {
 				is_interacting = true;
 			}
+			if (breakloop) {
+				// printf ("BreakLoop\n");
+				break;
+			}
 		}
 
+		if (!breakloop)
 		if (message != nullptr || (n_past > 0 && is_interacting)) {
 			printf ("message is not nul\n");
 			if (params.instruct || params.chatml) {
 				printf ("\n> ");
 			}
 
 			if (params.input_prefix_bos) {
-				embd_inp.push_back(llama_token_bos(model));
+				embd_inp.push_back (llama_token_bos(model));
 			}
 
 			std::string buffer;
@@ -702,20 +703,24 @@ int main_r2ai_message(const char *message) {
 			// console::set_display(console::user_input);
 
 			std::string line;
 			if (breakloop) {
 				// printf ("BreakLoop\n");
 				// break;
 			}
 			bool another_line = true;
-			if (message == nullptr || message == NULL) {
+			if (message == nullptr) {
+				// initial prompt
 				do {
-					another_line = cxxreadline(line, params.multiline_input);
+					another_line = cxxreadline (line, params.multiline_input);
 					buffer += line;
 				} while (another_line);
+				if (stdin_borken) {
+					break;
+				}
 			} else {
 				printf ("kkkkk\n");
 				buffer += message;
-				// message = nullptr;
+				// message = nullptr;
 				breakloop = true;
 			}
 
@@ -727,13 +732,13 @@ int main_r2ai_message(const char *message) {
 			if (buffer.length() > 1) {
 				// printf ("POP BACK %d\n", buffer.length());
 				// append input suffix if any
 #if 0
 				if (!params.input_suffix.empty()) {
 					// printf ("appending input suffix: '%s'\n", params.input_suffix.c_str());
 					// printf("%s", params.input_suffix.c_str());
 				}
 
 				// printf ("buffer: '%s'\n", buffer.c_str());
 #endif
 				const size_t original_size = embd_inp.size();
 
 				// instruct mode: insert instruction prefix
@@ -747,7 +752,7 @@ int main_r2ai_message(const char *message) {
 				embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
 			}
 			if (params.escape) {
-				process_escapes(buffer);
+				process_escapes (buffer);
 			}
 
 			const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
@@ -780,8 +785,9 @@ int main_r2ai_message(const char *message) {
 		}
 
 		if (n_past > 0) {
+			// printf ("npast > 0\n");
 			if (is_interacting) {
-				llama_sampling_reset(ctx_sampling);
+				llama_sampling_reset (ctx_sampling);
 			}
 			is_interacting = false;
 		}
@@ -803,11 +809,25 @@
 	return 0;
 }
 
+extern bool main_r2ai_preinit(int argc, char **argv) {
+	const char *_argv[] = {"r2ai"};
+	if (argc < 1 || argv == nullptr) {
+		argc = 1;
+		argv = (char **)_argv;
+	}
+	if (!gpt_params_parse (argc, argv, params)) {
+		return false;
+	}
+	return true;
+}
+
 #if !R2AI
 int main(int argc, char **argv) {
-	if (!gpt_params_parse (argc, argv, params)) {
+	if (!main_r2ai_preinit (argc, argv)) {
 		return 1;
 	}
+	if (!gpt_params_parse (argc, (char **)argv, params)) {
+	}
 	main_r2ai_init (R2AI_MODEL_PATH);
 	main_r2ai_message (NULL); // if null just get into the repl
 	main_r2ai_fini ();
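Taken together, the exported entry points compose as in the #if !R2AI main() above: preinit parses gpt_params (substituting a dummy {"r2ai"} argv when none is given), init loads the model, message answers once when non-NULL or drops into the repl when NULL, and fini tears everything down. A sketch of the call order, assuming it runs inside the same translation unit (main_r2ai_fini is static there) and with a hypothetical message string:

if (main_r2ai_preinit (0, NULL)) {            // NULL argv: falls back to {"r2ai"}
	if (main_r2ai_init (R2AI_MODEL_PATH)) {   // load the model, fill the globals
		main_r2ai_message ("hello r2ai");     // non-NULL: answer once, then breakloop
		// main_r2ai_message (NULL);          // NULL: interactive repl on stdin
	}
	main_r2ai_fini ();
}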
@@ -33,6 +33,9 @@
 
 static void r2ai_parseflag(RCore *core, const char *input) {
 	switch (*input) {
 	case 'e':
 		r_core_cmd0 (core, "-e r2ai.");
 		break;
+	case 'v':
+		r_cons_printf ("r2ai-native-v0.1\n");
+#if 0
@@ -46,11 +49,9 @@ static void r2ai_parseflag(RCore *core, const char *input) {
 	}
 }
 
-static void r2ai_init(RCore *core) {
-}
-
 extern int main_r2ai_message(const char *message);
 extern bool main_r2ai_init(const char *model_path);
+extern bool main_r2ai_preinit(int argc, char **argv);
 
 static void r2ai_message(RCore *core, const char *input) {
 	const char *model_path = r_config_get (core->config, "r2ai.model");
@@ -60,7 +61,6 @@ static void r2ai_message(RCore *core, const char *input) {
 	char *s = r_str_newf ("<s>[INST]%s[/INST]%s</s>%s", prompt, prompt_reply, input);
 	main_r2ai_message (input);
 	free (s);
-	r2ai_init (core);
 }
 
 static int r_cmd_r2ai_init(void *user, const char *input) {
@@ -75,6 +75,7 @@ static int r_cmd_r2ai_init(void *user, const char *input) {
 	} else {
 		R_LOG_INFO ("Can't init");
 	}
+	main_r2ai_preinit (0, NULL);
 	return true;
 }