Files
airllm-fork-nodejs/node_modules/node-llama-cpp/llama/addon/AddonSampler.cpp
2026-02-05 15:27:49 +08:00

512 lines
18 KiB
C++

#include <cmath>
#include "common/common.h"
#include "globals/addonLog.h"
#include "ggml.h"
#include "llama.h"
#include "AddonGrammarEvaluationState.h"
#include "AddonSampler.h"
AddonSampler::AddonSampler(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonSampler>(info) {
model = Napi::ObjectWrap<AddonModel>::Unwrap(info[0].As<Napi::Object>());
model->Ref();
tokenCandidates.resize(llama_vocab_n_tokens(model->vocab));
tokenCandidates.reserve(llama_vocab_n_tokens(model->vocab));
}
AddonSampler::~AddonSampler() {
dispose();
}
void AddonSampler::dispose() {
if (disposed) {
return;
}
disposed = true;
model->Unref();
freeChain();
if (temperatureSampler != nullptr) {
llama_sampler_free(temperatureSampler);
temperatureSampler = nullptr;
}
if (greedySampler != nullptr) {
llama_sampler_free(greedySampler);
greedySampler = nullptr;
}
if (minPSampler != nullptr) {
llama_sampler_free(minPSampler);
minPSampler = nullptr;
}
if (topKSampler != nullptr) {
llama_sampler_free(topKSampler);
topKSampler = nullptr;
}
if (topPSampler != nullptr) {
llama_sampler_free(topPSampler);
topPSampler = nullptr;
}
if (seedSampler != nullptr) {
llama_sampler_free(seedSampler);
seedSampler = nullptr;
}
if (repeatPenaltySampler != nullptr) {
llama_sampler_free(repeatPenaltySampler);
repeatPenaltySampler = nullptr;
}
if (tokenBiasSampler != nullptr) {
llama_sampler_free(tokenBiasSampler);
tokenBiasSampler = nullptr;
}
if (grammarEvaluationState != nullptr) {
grammarEvaluationState->Unref();
grammarEvaluationState = nullptr;
}
}
void AddonSampler::freeChain() {
if (chain == nullptr) {
return;
}
// ensure existing state of samplers isn't cleared
while (llama_sampler_chain_n(chain) > 0) {
llama_sampler_chain_remove(chain, 0);
}
llama_sampler_free(chain);
chain = nullptr;
}
void AddonSampler::rebuildChainIfNeeded() {
if (disposed) {
throw std::runtime_error("Sampler is disposed");
}
if (chain != nullptr) {
return;
}
auto sampler_params = llama_sampler_chain_default_params();
chain = llama_sampler_chain_init(sampler_params);
if (tokenBiasSampler != nullptr) {
llama_sampler_chain_add(chain, tokenBiasSampler);
}
if (repeatPenaltySampler != nullptr) {
llama_sampler_chain_add(chain, repeatPenaltySampler);
}
if (grammarEvaluationState != nullptr) {
llama_sampler_chain_add(chain, grammarEvaluationState->sampler);
}
if (greedySampler != nullptr) {
llama_sampler_chain_add(chain, greedySampler);
} else {
if (topKSampler != nullptr) {
llama_sampler_chain_add(chain, topKSampler);
}
if (topPSampler != nullptr) {
llama_sampler_chain_add(chain, topPSampler);
}
if (minPSampler != nullptr) {
llama_sampler_chain_add(chain, minPSampler);
}
if (temperatureSampler != nullptr) {
llama_sampler_chain_add(chain, temperatureSampler);
}
if (seedSampler != nullptr) {
llama_sampler_chain_add(chain, seedSampler);
}
}
}
void AddonSampler::acceptToken(llama_token token) {
if (repeatPenaltySampler != nullptr) {
llama_sampler_accept(repeatPenaltySampler, token);
repeatPenalty_lastTokens.push_back(token);
}
if (grammarEvaluationState != nullptr && grammarEvaluationState->sampler != nullptr && !llama_vocab_is_eog(model->vocab, token)) {
llama_sampler_accept(grammarEvaluationState->sampler, token);
}
}
Napi::Value AddonSampler::Dispose(const Napi::CallbackInfo& info) {
dispose();
return info.Env().Undefined();
}
Napi::Value AddonSampler::ApplyConfig(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Sampler is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
size_t min_keep = std::max(1, n_probs);
Napi::Object config = info[0].As<Napi::Object>();
if (config.Has("temperature")) {
auto temperature = config.Get("temperature").As<Napi::Number>().FloatValue();
if (temperature != temperatureSampler_temperature || !temperatureSampler_initialized) {
temperatureSampler_initialized = true;
temperatureSampler_temperature = temperature;
freeChain();
if (temperatureSampler != nullptr) {
llama_sampler_free(temperatureSampler);
temperatureSampler = nullptr;
}
if (temperatureSampler_temperature <= 0) {
greedySampler = llama_sampler_init_greedy();
} else {
temperatureSampler = llama_sampler_init_temp(temperatureSampler_temperature);
if (greedySampler != nullptr) {
llama_sampler_free(greedySampler);
greedySampler = nullptr;
}
}
}
} else {
if (temperatureSampler != nullptr) {
freeChain();
llama_sampler_free(temperatureSampler);
temperatureSampler = nullptr;
}
if (greedySampler == nullptr) {
greedySampler = llama_sampler_init_greedy();
}
}
if (config.Has("minP")) {
auto minP = config.Get("minP").As<Napi::Number>().FloatValue();
if (minP != minPSampler_minP) {
minPSampler_minP = minP;
freeChain();
if (minPSampler != nullptr) {
llama_sampler_free(minPSampler);
minPSampler = nullptr;
}
if (minPSampler_minP != 0) {
minPSampler = llama_sampler_init_min_p(minPSampler_minP, min_keep);
}
}
} else if (minPSampler != nullptr) {
freeChain();
llama_sampler_free(minPSampler);
minPSampler = nullptr;
}
if (config.Has("topK")) {
auto topK = config.Get("topK").As<Napi::Number>().Int32Value();
if (topK != topKSampler_topK || !topKSampler_initialized) {
topKSampler_initialized = true;
topKSampler_topK = topK;
freeChain();
if (topKSampler != nullptr) {
llama_sampler_free(topKSampler);
topKSampler = nullptr;
}
const int32_t resolved_top_k = topKSampler_topK <= 0
? llama_vocab_n_tokens(model->vocab)
: std::min(topKSampler_topK, llama_vocab_n_tokens(model->vocab));
topKSampler = llama_sampler_init_top_k(resolved_top_k);
}
} else if (topKSampler != nullptr) {
freeChain();
llama_sampler_free(topKSampler);
topKSampler = nullptr;
}
if (config.Has("topP")) {
auto topP = config.Get("topP").As<Napi::Number>().FloatValue();
if (topP != topPSampler_topP) {
topPSampler_topP = topP;
freeChain();
if (topPSampler != nullptr) {
llama_sampler_free(topPSampler);
topPSampler = nullptr;
}
if (topPSampler_topP >= 1) {
topPSampler = llama_sampler_init_top_p(topPSampler_topP, min_keep);
}
}
} else if (topPSampler != nullptr) {
freeChain();
llama_sampler_free(topPSampler);
topPSampler = nullptr;
}
if (config.Has("seed")) {
auto seed = config.Get("seed").As<Napi::Number>().Uint32Value();
if (seed != seedSampler_seed || seedSampler == nullptr) {
seedSampler_seed = seed;
freeChain();
if (seedSampler != nullptr) {
llama_sampler_free(seedSampler);
seedSampler = nullptr;
}
seedSampler = llama_sampler_init_dist(seedSampler_seed);
}
} else if (seedSampler == nullptr) {
freeChain();
seedSampler = llama_sampler_init_dist(time(NULL));
}
if (config.Has("repeatPenaltyTokens")) {
Napi::Uint32Array repeat_penalty_tokens_uint32_array = config.Get("repeatPenaltyTokens").As<Napi::Uint32Array>();
auto repeatPenalty = config.Has("repeatPenalty")
? config.Get("repeatPenalty").As<Napi::Number>().FloatValue()
: 1;
auto repeatPenaltyMaxTokens = config.Has("repeatPenaltyMaxTokens")
? config.Get("repeatPenaltyMaxTokens").As<Napi::Number>().Int32Value()
: 64;
auto repeatPenaltyPresencePenalty = config.Has("repeatPenaltyPresencePenalty")
? config.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue()
: 0;
auto repeatPenaltyFrequencyPenalty = config.Has("repeatPenaltyFrequencyPenalty")
? config.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue()
: 0;
auto repeatPenaltyEnabled = repeatPenalty != 1 && repeatPenaltyMaxTokens > 0;
bool shouldCreateSampler = false;
if (!repeatPenaltyEnabled) {
if (repeatPenaltySampler != nullptr) {
freeChain();
llama_sampler_free(repeatPenaltySampler);
repeatPenaltySampler = nullptr;
}
} else if (repeatPenaltySampler == nullptr) {
freeChain();
shouldCreateSampler = true;
} else {
bool existingSamplerMatchesConfig = true;
existingSamplerMatchesConfig &= repeatPenalty_maxTokens == repeatPenaltyMaxTokens;
existingSamplerMatchesConfig &= repeatPenalty_penalty == repeatPenalty;
existingSamplerMatchesConfig &= repeatPenalty_presencePenalty == repeatPenaltyPresencePenalty;
existingSamplerMatchesConfig &= repeatPenalty_frequencyPenalty == repeatPenaltyFrequencyPenalty;
if (existingSamplerMatchesConfig) {
if (repeat_penalty_tokens_uint32_array.ElementLength() > 0) {
const auto firstToken = static_cast<llama_token>(repeat_penalty_tokens_uint32_array[0]);
if (repeatPenalty_lastTokens.rat(0) != firstToken &&
repeatPenalty_lastTokens.size() == repeatPenalty_maxTokens &&
repeat_penalty_tokens_uint32_array.ElementLength() == repeatPenalty_maxTokens
) {
const auto lastToken = static_cast<llama_token>(repeat_penalty_tokens_uint32_array[repeat_penalty_tokens_uint32_array.ElementLength() - 1]);
llama_sampler_accept(repeatPenaltySampler, lastToken);
repeatPenalty_lastTokens.push_back(lastToken);
}
}
for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength() && existingSamplerMatchesConfig; i++) {
auto token = static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]);
if (i < repeatPenalty_lastTokens.size()) {
existingSamplerMatchesConfig &= repeatPenalty_lastTokens.rat(i) == token;
} else {
llama_sampler_accept(repeatPenaltySampler, token);
repeatPenalty_lastTokens.push_back(token);
}
}
}
if (!existingSamplerMatchesConfig) {
freeChain();
llama_sampler_free(repeatPenaltySampler);
repeatPenaltySampler = nullptr;
shouldCreateSampler = true;
}
}
if (shouldCreateSampler) {
repeatPenaltySampler = llama_sampler_init_penalties(
repeatPenaltyMaxTokens,
repeatPenalty,
repeatPenaltyFrequencyPenalty,
repeatPenaltyPresencePenalty
);
repeatPenalty_lastTokens = RingBuffer<llama_token>(repeatPenaltyMaxTokens);
for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
llama_sampler_accept(repeatPenaltySampler, static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
repeatPenalty_lastTokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
}
repeatPenalty_maxTokens = repeatPenaltyMaxTokens;
repeatPenalty_penalty = repeatPenalty;
repeatPenalty_presencePenalty = repeatPenaltyPresencePenalty;
repeatPenalty_frequencyPenalty = repeatPenaltyFrequencyPenalty;
}
} else if (repeatPenaltySampler != nullptr) {
freeChain();
llama_sampler_free(repeatPenaltySampler);
repeatPenaltySampler = nullptr;
}
if (config.Has("tokenBiasKeys") && config.Has("tokenBiasValues")) {
Napi::Uint32Array tokenBiasKeys = config.Get("tokenBiasKeys").As<Napi::Uint32Array>();
Napi::Float32Array tokenBiasValues = config.Get("tokenBiasValues").As<Napi::Float32Array>();
if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength() && tokenBiasKeys.ElementLength() > 0) {
bool existingSamplerMatchesConfig = tokenBiasSampler != nullptr;
if (tokenBiasSampler != nullptr && tokenBiasSampler_biases.size() == tokenBiasKeys.ElementLength()) {
for (size_t i = 0; i < tokenBiasKeys.ElementLength() && existingSamplerMatchesConfig; i++) {
existingSamplerMatchesConfig &= tokenBiasSampler_biases[i].token == static_cast<llama_token>(tokenBiasKeys[i]);
existingSamplerMatchesConfig &= tokenBiasSampler_biases[i].bias == tokenBiasValues[i];
}
} else {
existingSamplerMatchesConfig = false;
}
if (!existingSamplerMatchesConfig) {
if (tokenBiasSampler != nullptr) {
freeChain();
llama_sampler_free(tokenBiasSampler);
tokenBiasSampler = nullptr;
}
tokenBiasSampler_biases.clear();
tokenBiasSampler_biases.reserve(tokenBiasKeys.ElementLength());
for (size_t i = 0; i < tokenBiasKeys.ElementLength(); i++) {
tokenBiasSampler_biases.emplace_back(llama_logit_bias { static_cast<llama_token>(tokenBiasKeys[i]), tokenBiasValues[i] });
}
tokenBiasSampler = llama_sampler_init_logit_bias(
llama_vocab_n_tokens(model->vocab),
tokenBiasSampler_biases.size(),
tokenBiasSampler_biases.data()
);
}
} else if (tokenBiasSampler != nullptr) {
freeChain();
llama_sampler_free(tokenBiasSampler);
tokenBiasSampler = nullptr;
}
} else if (tokenBiasSampler != nullptr) {
freeChain();
llama_sampler_free(tokenBiasSampler);
tokenBiasSampler = nullptr;
}
if (config.Has("grammarEvaluationState")) {
const auto configGrammarEvaluationState =
Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(config.Get("grammarEvaluationState").As<Napi::Object>());
if (grammarEvaluationState != configGrammarEvaluationState) {
freeChain();
if (grammarEvaluationState != nullptr) {
grammarEvaluationState->Unref();
grammarEvaluationState = nullptr;
}
grammarEvaluationState = configGrammarEvaluationState;
grammarEvaluationState->Ref();
}
} else if (grammarEvaluationState != nullptr) {
freeChain();
grammarEvaluationState->Unref();
grammarEvaluationState = nullptr;
}
return info.Env().Undefined();
}
Napi::Value AddonSampler::AcceptGrammarEvaluationStateToken(const Napi::CallbackInfo& info) {
AddonGrammarEvaluationState* grammar_evaluation_state =
Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(info[0].As<Napi::Object>());
llama_token tokenId = info[1].As<Napi::Number>().Int32Value();
if ((grammar_evaluation_state)->sampler != nullptr) {
try {
llama_sampler_accept((grammar_evaluation_state)->sampler, tokenId);
} catch (const std::exception & e) {
Napi::Error::New(info.Env(), std::string("Failed to accept token in grammar sampler: ") + e.what()).ThrowAsJavaScriptException();
return info.Env().Undefined();
} catch (...) {
Napi::Error::New(info.Env(), "Failed to accept token in grammar sampler").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
}
return info.Env().Undefined();
}
Napi::Value AddonSampler::CanBeNextTokenForGrammarEvaluationState(const Napi::CallbackInfo& info) {
AddonGrammarEvaluationState* grammar_evaluation_state =
Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(info[0].As<Napi::Object>());
llama_token tokenId = info[1].As<Napi::Number>().Int32Value();
if ((grammar_evaluation_state)->sampler != nullptr) {
std::vector<llama_token_data> candidates;
candidates.reserve(1);
candidates.emplace_back(llama_token_data { tokenId, 1, 0.0f });
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
try {
llama_sampler_apply((grammar_evaluation_state)->sampler, &candidates_p);
} catch (const std::exception & e) {
addonLog(GGML_LOG_LEVEL_DEBUG, std::string("Failed to apply grammar sampler: ") + e.what());
return Napi::Boolean::New(info.Env(), false);
} catch (...) {
return Napi::Boolean::New(info.Env(), false);
}
if (candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) {
return Napi::Boolean::New(info.Env(), false);
}
return Napi::Boolean::New(info.Env(), true);
}
return Napi::Boolean::New(info.Env(), false);
}
void AddonSampler::init(Napi::Object exports) {
exports.Set(
"AddonSampler",
DefineClass(
exports.Env(),
"AddonSampler",
{
InstanceMethod("dispose", &AddonSampler::Dispose),
InstanceMethod("applyConfig", &AddonSampler::ApplyConfig),
StaticMethod("acceptGrammarEvaluationStateToken", &AddonSampler::AcceptGrammarEvaluationStateToken),
StaticMethod("canBeNextTokenForGrammarEvaluationState", &AddonSampler::CanBeNextTokenForGrammarEvaluationState),
}
)
);
}