Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ bool utf8_check_validity(const char* str, size_t length) {
return true; // All bytes were valid
}

std::string token_buffer;
} // namespace

namespace executorch_jni {
Expand All @@ -89,21 +88,11 @@ class ExecuTorchLlmCallbackJni
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/extension/llm/LlmCallback;";

void onResult(std::string result) const {
void onResult(const std::string& result) const {
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");

token_buffer += result;
if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
ET_LOG(
Info, "Current token buffer is not valid UTF-8. Waiting for more.");
return;
}
result = token_buffer;
token_buffer = "";
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
method(self(), s);
method(self(), facebook::jni::make_jstring(result));
}

void onStats(const llm::Stats& result) const {
Expand Down Expand Up @@ -245,6 +234,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
jint num_bos,
jint num_eos) {
float effective_temperature = temperature >= 0 ? temperature : temperature_;
std::string token_buffer;
auto token_callback = [callback, &token_buffer](const std::string& token) {
token_buffer += token;
if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
ET_LOG(
Info, "Current token buffer is not valid UTF-8. Waiting for more.");
return;
}
std::string result = token_buffer;
token_buffer.clear();
callback->onResult(result);
};

if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
prefill_inputs_.clear();
Expand All @@ -261,7 +263,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
multi_modal_runner_->generate(
std::move(inputs),
config,
[callback](const std::string& result) { callback->onResult(result); },
token_callback,
[callback](const llm::Stats& result) { callback->onStats(result); });
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
Expand All @@ -274,7 +276,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
runner_->generate(
prompt->toStdString(),
config,
[callback](std::string result) { callback->onResult(result); },
token_callback,
[callback](const llm::Stats& result) { callback->onStats(result); });
}
return 0;
Expand Down
Loading