// Patch summary (preamble text; everything from the "diff --git" header onward is unchanged).
//
// Target: extension/android/jni/jni_layer_llama.cpp
//
// What the hunks do:
//   * Remove the file-level `std::string token_buffer;` that sat just before `} // namespace`.
//   * ExecuTorchLlmCallbackJni::onResult now takes `const std::string&` and forwards
//     `facebook::jni::make_jstring(result)` straight to the Java `onResult` method; the
//     UTF-8 accumulation it used to perform on the shared buffer is deleted.
//   * In the ExecuTorchLlmJni method whose parameter list ends in `jint num_bos, jint num_eos`,
//     a local `token_buffer` and a `token_callback` lambda are added. The lambda appends each
//     token to the buffer; while `utf8_check_validity` rejects the buffer it logs at Info level
//     and returns without emitting; otherwise it copies the buffer into `result`, clears the
//     buffer, and calls `callback->onResult(result)`.
//   * Both `multi_modal_runner_->generate(...)` and `runner_->generate(...)` now receive
//     `token_callback` in place of their inline `[callback](...) { callback->onResult(result); }`
//     lambdas.
//
// NOTE(review): `token_callback` captures `token_buffer` by reference, and the buffer is a local
//   of the enclosing method. This is only safe if neither runner keeps or invokes the callback
//   after `generate` returns -- confirm against the runner implementations.
// NOTE(review): nothing in the visible hunks flushes `token_buffer` after `generate`; if
//   generation stops with an incomplete UTF-8 sequence buffered, those bytes look like they are
//   dropped when the method returns -- confirm this is intended.
// NOTE(review): since `onResult` now takes a const reference, `std::string result = token_buffer;`
//   is an extra copy; passing `token_buffer` directly and clearing it afterwards would avoid it.
// NOTE(review): this copy of the patch appears to have lost its line breaks and some template
//   argument lists (e.g. `cls->getMethod)>("onResult")`, `std::vector inputs`, `HybridClass {`),
//   so it presumably will not apply as-is -- regenerate it from the original commit before use.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 23185a5b6cd..8ddfbcaaeda 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -78,7 +78,6 @@ bool utf8_check_validity(const char* str, size_t length) { return true; // All bytes were valid } -std::string token_buffer; } // namespace namespace executorch_jni { @@ -89,21 +88,11 @@ class ExecuTorchLlmCallbackJni constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/extension/llm/LlmCallback;"; - void onResult(std::string result) const { + void onResult(const std::string& result) const { static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); static const auto method = cls->getMethod)>("onResult"); - - token_buffer += result; - if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { - ET_LOG( - Info, "Current token buffer is not valid UTF-8. Waiting for more."); - return; - } - result = token_buffer; - token_buffer = ""; - facebook::jni::local_ref s = facebook::jni::make_jstring(result); - method(self(), s); + method(self(), facebook::jni::make_jstring(result)); } void onStats(const llm::Stats& result) const { @@ -245,6 +234,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos, jint num_eos) { float effective_temperature = temperature >= 0 ? temperature : temperature_; + std::string token_buffer; + auto token_callback = [callback, &token_buffer](const std::string& token) { + token_buffer += token; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. 
Waiting for more."); + return; + } + std::string result = token_buffer; + token_buffer.clear(); + callback->onResult(result); + }; + if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { std::vector inputs = prefill_inputs_; prefill_inputs_.clear(); @@ -261,7 +263,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { multi_modal_runner_->generate( std::move(inputs), config, - [callback](const std::string& result) { callback->onResult(result); }, + token_callback, [callback](const llm::Stats& result) { callback->onStats(result); }); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { executorch::extension::llm::GenerationConfig config{ @@ -274,7 +276,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { runner_->generate( prompt->toStdString(), config, - [callback](std::string result) { callback->onResult(result); }, + token_callback, [callback](const llm::Stats& result) { callback->onStats(result); }); } return 0;