import path from "path";
import { acquireLock, AsyncDisposeAggregator, DisposeAggregator, DisposedError, EventRelay, withLock } from "lifecycle-utils";
import { removeNullFields } from "../../utils/removeNullFields.js";
import { compareTokens } from "../../utils/compareTokens.js";
import { DisposeGuard } from "../../utils/DisposeGuard.js";
import { TokenMeter } from "../TokenMeter.js";
import { UnsupportedError } from "../../utils/UnsupportedError.js";
import { pushAll } from "../../utils/pushAll.js";
import { safeEventCallback } from "../../utils/safeEventCallback.js";
import { GgufArchitectureType } from "../../gguf/types/GgufMetadataTypes.js";
import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js";
import { LlamaSampler } from "./LlamaSampler.js";
import { padSafeContextSize } from "./utils/padSafeContextSize.js";

const defaultLoraScale = 1;
const shrinkRetriesMinContextSize = 4096;
const defaultMaxPunishTokens = 64;
const defaultFailedCreationRemedy = {
    retries: 16,
    autoContextSizeShrink: 0.16
};
const defaultEvaluationPriority = 5;
const decodeSyncWorkaround = {
    vulkanLock: {}
};

export class LlamaContext {
    /** @internal */ _llama;
    /** @internal */ _ctx;
    /** @internal */ _onReclaimUnusedSequenceId = new EventRelay();
    /** @internal */ _backendContextDisposeGuard;
    /** @internal */ _model;
    /** @internal */ _contextSize;
    /** @internal */ _batchSize;
    /** @internal */ _flashAttention;
    /** @internal */ _idealThreads;
    /** @internal */ _minThreads;
    /** @internal */ _performanceTracking;
    /** @internal */ _totalSequences;
    /** @internal */ _unusedSequenceIds = [];
    /** @internal */ _batchingOptions;
    /** @internal */ _swaFullCache = false;
    /** @internal */ _queuedDecodeSequenceIds = new Set();
    /** @internal */ _queuedDecodes = [];
    /** @internal */ _disposeAggregator = new AsyncDisposeAggregator();
    /** @internal */ _modelPreventDisposalHandle;
    /** @internal */ _loraAdapters = new Set();
    /** @internal */ _nextGeneratedSequenceId = 0;
    /** @internal */ _dispatchDecodeScheduled = false;
    /** @internal */ _batchDispatchPending = false;
    /** @internal */ _threadSplitterConsumer;
    /** @internal */ _freeReservedThreadsTimeout;
    /** @internal */ _currentDispatchBatchHandle = {};
    /** @internal */ _allocatedContextSize;
    /** @internal */ _disposed = false;
    onDispose = new EventRelay();
    constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, swaFullCache = _model.defaultContextSwaFullCache, performanceTracking = false, _embeddings, _ranking }) {
        if (_model.disposed)
            throw new DisposedError();
        const kvUnified = false;
        this._llama = _model._llama;
        this._model = _model;
        this._backendContextDisposeGuard = new DisposeGuard([this._model._backendModelDisposeGuard]);
        this._modelPreventDisposalHandle = this._model._backendModelDisposeGuard.createPreventDisposalHandle();
        this._totalSequences = Math.max(1, Math.floor(sequences));
        this._contextSize = kvUnified
            ? Math.floor(padSafeContextSize(Math.max(2, contextSize) * this._totalSequences, "up") / this._totalSequences)
            : padSafeContextSize(Math.max(2, contextSize), "up");
        this._batchSize = Math.max(batchSize, this._totalSequences);
        this._flashAttention = flashAttention;
        this._idealThreads = typeof threads === "number"
            ? this._llama._threadsSplitter.normalizeThreadsValue(threads)
            : this._llama._threadsSplitter.normalizeThreadsValue(threads?.ideal ?? (this._llama.maxThreads === 0
                ? this._llama.cpuMathCores
                : this._llama.maxThreads));
        this._minThreads = Math.max(1, typeof threads === "number"
            ? 1
            : this._llama._threadsSplitter.normalizeThreadsValue(threads?.min ?? 1));
        this._performanceTracking = !!performanceTracking;
        this._swaFullCache = !!swaFullCache;
        this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({
            contextSize: padSafeContextSize(this._contextSize * this._totalSequences, "up"), // each sequence needs its own <contextSize> of cells
            batchSize: this._batchSize + ((!this._swaFullCache && this.model.fileInsights.swaSize != null && this.model.fileInsights.swaSize > 0)
                ? 1 // +1 to handle edge cases with SWA KV cache
                : 0),
            sequences: this._totalSequences,
            flashAttention: this._flashAttention,
            threads: this._idealThreads,
            embeddings: _embeddings,
            ranking: _ranking,
            performanceTracking: this._performanceTracking,
            swaFullCache: this._swaFullCache
        }));
        this._batchingOptions = {
            dispatchSchedule: batchingDispatchSchedule,
            itemPrioritizationStrategy: batchingItemsPrioritizationStrategy
        };
        this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this);
        this._freeReservedThreads = this._freeReservedThreads.bind(this);
        this._disposeAggregator.add(() => {
            this._disposed = true;
        });
        this._disposeAggregator.add(this._onReclaimUnusedSequenceId);
        this._disposeAggregator.add(this.onDispose.dispatchEvent);
        this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextIfReferenced.bind(null, new WeakRef(this))));
        this._disposeAggregator.add(async () => {
            await this._backendContextDisposeGuard.acquireDisposeLock();
            await this._ctx.dispose();
            this._modelPreventDisposalHandle.dispose();
        });
    }
    async dispose() {
        if (this._disposed)
            return;
        this._disposed = true;
        await this._disposeAggregator.dispose();
    }
    /** @hidden */
    [Symbol.asyncDispose]() {
        return this.dispose();
    }
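    // Illustrative usage sketch, assuming `model` is a loaded `LlamaModel`:
    // in runtimes with explicit resource management support, a context can be
    // disposed automatically when it goes out of scope:
    //     await using context = await model.createContext();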
    get disposed() {
        return this._disposed;
    }
    get model() {
        return this._model;
    }
    get contextSize() {
        return this._contextSize;
    }
    get batchSize() {
        return this._batchSize;
    }
    get flashAttention() {
        return this._flashAttention;
    }
    /**
     * The actual size of the state in the memory in bytes.
     * This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context.
     */
    get stateSize() {
        this._ensureNotDisposed();
        return this._ctx.getStateSize();
    }
    /** The number of threads currently used to evaluate tokens */
    get currentThreads() {
        this._ensureNotDisposed();
        return this._ctx.getThreads();
    }
    /**
     * The number of threads that are preferred to be used to evaluate tokens.
     *
     * The actual number of threads used may be lower when other evaluations are running in parallel.
     */
    get idealThreads() {
        return this._idealThreads;
    }
    getAllocatedContextSize() {
        this._ensureNotDisposed();
        if (this._allocatedContextSize == null)
            this._allocatedContextSize = this._ctx.getContextSize();
        return this._allocatedContextSize;
    }
    get totalSequences() {
        return this._totalSequences;
    }
    get sequencesLeft() {
        return this._totalSequences - this._nextGeneratedSequenceId + this._unusedSequenceIds.length;
    }
    /**
     * Before calling this method, make sure to check `sequencesLeft` to see whether there are any sequences left.
     * When there are no sequences left, this method will throw an error.
     */
    getSequence(options = {}) {
        const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, _tokenMeter } = options;
        this._ensureNotDisposed();
        const nextSequenceId = this._popSequenceId();
        if (nextSequenceId == null)
            throw new Error("No sequences left");
        return LlamaContextSequence._create({
            sequenceId: nextSequenceId,
            context: this,
            tokenMeter: _tokenMeter,
            contextShift: {
                size: contextShiftSize,
                strategy: contextShiftStrategy
            },
            tokenPredictor
        });
    }
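    // Illustrative usage sketch, assuming `context` is a `LlamaContext`:
    //     if (context.sequencesLeft === 0)
    //         throw new Error("no sequences left");
    //     const sequence = context.getSequence();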
    dispatchPendingBatch() {
        this._currentDispatchBatchHandle = {};
        this._dispatchDecodeScheduled = false;
        if (this._batchDispatchPending)
            return;
        this._batchDispatchPending = true;
        void withLock([this, "context"], async () => {
            this._currentDispatchBatchHandle = {};
            this._dispatchDecodeScheduled = false;
            this._batchDispatchPending = false;
            let shouldHaveAnotherLoop = this._queuedDecodes.length > 0;
            const queuedDecodeToMappedLogits = new Map();
            const resolvePrioritizationStrategy = () => {
                try {
                    this._ensureNotDisposed();
                    return resolveBatchItemsPrioritizationStrategy(this._batchingOptions.itemPrioritizationStrategy);
                }
                catch (err) {
                    this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
                }
                return null;
            };
            const getOrderedQueuedDecodes = (prioritizationStrategy) => {
                const batchItemToQueuedDecodeMap = new Map();
                const batchItemsList = [];
                for (const queuedDecode of this._queuedDecodes) {
                    const batchItem = {
                        tokens: queuedDecode.tokens,
                        logits: queuedDecode.logits,
                        evaluationPriority: queuedDecode.evaluationPriority
                    };
                    batchItemToQueuedDecodeMap.set(batchItem, queuedDecode);
                    batchItemsList.push(batchItem);
                }
                let prioritizedItems;
                try {
                    prioritizedItems = prioritizationStrategy({
                        items: batchItemsList,
                        size: this._batchSize
                    });
                }
                catch (err) {
                    this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
                    return null;
                }
                return prioritizedItems.map((prioritizedItem) => {
                    const queuedDecode = batchItemToQueuedDecodeMap.get(prioritizedItem.item);
                    if (queuedDecode == null)
                        throw new Error("Received invalid batch item. Make sure you keep the original object reference " +
                            "of the batch item on `item` on `PrioritizedBatchItem` in your custom prioritization strategy");
                    return {
                        queuedDecode,
                        processAmount: prioritizedItem.processAmount
                    };
                });
            };
            const fitQueuedDecodesToABatch = (queuedDecodes, batchSize) => {
                const currentBatchItems = [];
                let currentBatchSize = 0;
                let batchTokenSlotsLeft = batchSize;
                for (const { queuedDecode, processAmount } of queuedDecodes) {
                    const resolvedProcessAmount = Math.min(processAmount <= 0 ? 1 : processAmount, queuedDecode.tokens.length, batchTokenSlotsLeft);
                    if (resolvedProcessAmount <= 0) {
                        if (batchTokenSlotsLeft === 0)
                            break;
                        continue;
                    }
                    batchTokenSlotsLeft -= resolvedProcessAmount;
                    currentBatchSize += resolvedProcessAmount;
                    currentBatchItems.push({
                        queuedDecode,
                        processAmount: resolvedProcessAmount
                    });
                }
                return {
                    currentBatchItems,
                    currentBatchSize
                };
            };
            const decodeTokenBatchItems = async (batchItems, currentBatchSize) => {
                const afterDecodeActions = [];
                const queuedDecodesToDelete = new Set();
                const currentQueuedDecodeItems = new Set();
                if (currentBatchSize !== 0)
                    this._ctx.initBatch(currentBatchSize);
                for (const { queuedDecode, processAmount } of batchItems) {
                    let batchLogitIndexes;
                    const tokensToProcess = queuedDecode.tokens.slice(0, processAmount);
                    const tokenIndexesWithLogitsToProcess = queuedDecode.logits.slice(0, processAmount)
                        .map((logit, index) => (logit ? index : undefined))
                        .filter((index) => index != undefined);
                    const numberOfOutputTokens = tokenIndexesWithLogitsToProcess.length;
                    TokenMeter.useTokens(queuedDecode.tokenMeter, Math.max(0, tokensToProcess.length - numberOfOutputTokens), "input");
                    TokenMeter.useTokens(queuedDecode.tokenMeter, numberOfOutputTokens, "output");
                    try {
                        batchLogitIndexes = this._ctx.addToBatch(queuedDecode.sequenceId, queuedDecode.firstTokenSequenceIndex, Uint32Array.from(tokensToProcess), Uint32Array.from(tokenIndexesWithLogitsToProcess));
                    }
                    catch (err) {
                        this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err);
                        continue;
                    }
                    currentQueuedDecodeItems.add(queuedDecode);
                    if (queuedDecode.tokens.length === processAmount) {
                        queuedDecodesToDelete.add(queuedDecode);
                        afterDecodeActions.push({
                            queuedDecode,
                            batchLogitIndexes,
                            batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
                            firstTokenIndex: queuedDecode.firstTokenSequenceIndex,
                            returnResults: true
                        });
                    }
                    else {
                        if (batchLogitIndexes.length > 0)
                            afterDecodeActions.push({
                                queuedDecode,
                                batchLogitIndexes,
                                batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
                                firstTokenIndex: queuedDecode.firstTokenSequenceIndex
                            });
                        queuedDecode.tokens = queuedDecode.tokens.slice(processAmount);
                        queuedDecode.logits = queuedDecode.logits.slice(processAmount);
                        queuedDecode.firstTokenSequenceIndex += processAmount;
                    }
                }
                for (let i = 0; i < this._queuedDecodes.length; i++) {
                    const queuedDecode = this._queuedDecodes[i];
                    if (queuedDecodesToDelete.has(queuedDecode)) {
                        this._queuedDecodes.splice(i, 1);
                        this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId);
                        i--;
                    }
                }
                if (currentBatchSize !== 0) {
                    const allocationResult = this._threadSplitterConsumer?.getAllocationToConsume();
                    const [threadsToUse, consumerHandle] = allocationResult instanceof Promise
                        ? await allocationResult ?? []
                        : allocationResult ?? [];
                    try {
                        if (threadsToUse != null)
                            this._ctx.setThreads(threadsToUse);
                        await this._ctx.decodeBatch();
                        consumerHandle?.dispose();
                    }
                    catch (err) {
                        consumerHandle?.dispose();
                        this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err);
                        return;
                    }
                }
                function finishAfterDecodeAction(action, mappedLogitValues) {
                    if (mappedLogitValues != null && mappedLogitValues.length > 0) {
                        if (queuedDecodeToMappedLogits.has(action.queuedDecode))
                            pushAll(queuedDecodeToMappedLogits.get(action.queuedDecode), mappedLogitValues);
                        else
                            queuedDecodeToMappedLogits.set(action.queuedDecode, mappedLogitValues);
                    }
                    if (action.returnResults != null) {
                        const [accept] = action.queuedDecode.response;
                        const mappedLogits = queuedDecodeToMappedLogits.get(action.queuedDecode) ?? [];
                        queuedDecodeToMappedLogits.delete(action.queuedDecode);
                        accept(mappedLogits);
                    }
                }
                const afterDecodeActionResults = afterDecodeActions.map((action) => {
                    if (action.batchLogitIndexes.length === 0) {
                        finishAfterDecodeAction(action);
                        return undefined;
                    }
                    const mappedLogitValues = [];
                    let promiseChain = undefined;
                    const batchLogitIndexes = action.batchLogitIndexes;
                    const batchLogitTokenIndexes = action.batchLogitTokenIndexes;
                    for (let i = 0; i < batchLogitIndexes.length; i++) {
                        const tokenIndex = batchLogitTokenIndexes[i];
                        const mappedValue = promiseChain != null
                            ? promiseChain
                                .then(() => action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex))
                            : action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex);
                        if (mappedValue instanceof Promise) {
                            promiseChain = mappedValue;
                            mappedLogitValues.push(mappedValue
                                .then((value) => [tokenIndex + action.firstTokenIndex, value]));
                        }
                        else
                            mappedLogitValues.push([tokenIndex + action.firstTokenIndex, mappedValue]);
                    }
                    if (promiseChain != null)
                        return Promise.all(mappedLogitValues)
                            .then((resolvedMappedLogitValues) => finishAfterDecodeAction(action, resolvedMappedLogitValues));
                    finishAfterDecodeAction(action, mappedLogitValues);
                    return undefined;
                });
                await Promise.all(afterDecodeActionResults);
            };
            const prioritizationStrategy = resolvePrioritizationStrategy();
            if (prioritizationStrategy == null)
                return; // all queued items are rejected and dequeued when we get here
            this._reserveThreads();
            try {
                while (shouldHaveAnotherLoop) {
                    const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizationStrategy);
                    if (orderedQueuedDecodes == null)
                        return; // all queued items are rejected and dequeued when we get here
                    const { currentBatchItems, currentBatchSize } = fitQueuedDecodesToABatch(orderedQueuedDecodes, this._batchSize);
                    let preventDisposalHandle;
                    try {
                        preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle();
                    }
                    catch (err) {
                        this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
                        return;
                    }
                    let decodeLock;
                    // this is a workaround to prevent Vulkan from crashing the process when decoding on multiple contexts in parallel
                    if (this._llama.gpu === "vulkan")
                        decodeLock = await acquireLock([decodeSyncWorkaround.vulkanLock, "decode"]);
                    try {
                        await decodeTokenBatchItems(currentBatchItems, currentBatchSize);
                        shouldHaveAnotherLoop = this._queuedDecodes.length > 0;
                    }
                    finally {
                        decodeLock?.dispose();
                        preventDisposalHandle.dispose();
                    }
                }
            }
            finally {
                this._scheduleToFreeReservedThreads();
            }
        });
    }
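    // Illustrative sketch of a custom dispatch schedule, assuming `model` is a `LlamaModel`;
    // any function that eventually calls the `dispatch` callback it receives is valid:
    //     const context = await model.createContext({
    //         batching: {
    //             dispatchSchedule: (dispatch) => setTimeout(dispatch, 5)
    //         }
    //     });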
    /**
     * Print the timings of token evaluation since the last print for this context.
     *
     * Requires the `performanceTracking` option to be enabled.
     *
     * > **Note:** it prints on the `LlamaLogLevel.info` level, so if you set the level of your `Llama` instance higher than that,
     * it won't print anything.
     */
    async printTimings() {
        this._ensureNotDisposed();
        if (!this._performanceTracking)
            throw new UnsupportedError("Performance tracking is not enabled");
        this._ctx.printTimings();
        await new Promise((accept) => setTimeout(accept, 0)); // wait for the logs to finish printing
    }
    /** @internal */
    async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter }, logitDataMapper) {
        return await new Promise((accept, reject) => {
            this._queuedDecodes.push({
                sequenceId,
                tokens,
                logits,
                firstTokenSequenceIndex,
                evaluationPriority,
                tokenMeter,
                response: [accept, reject],
                logitDataMapper
            });
            this._queuedDecodeSequenceIds.add(sequenceId);
            this._scheduleDecode();
        });
    }
    /** @internal */
    _reclaimUnusedSequenceId(sequenceId) {
        if (this._disposed)
            return;
        void withLock([this, "context"], async () => {
            if (this._disposed)
                return;
            this._ctx.disposeSequence(sequenceId);
            this._unusedSequenceIds.push(sequenceId);
            this._onReclaimUnusedSequenceId.dispatchEvent();
        });
    }
    /** @internal */
    _popSequenceId() {
        if (this._unusedSequenceIds.length > 0)
            return this._unusedSequenceIds.shift();
        if (this._nextGeneratedSequenceId < this._totalSequences) {
            const sequenceId = this._nextGeneratedSequenceId;
            this._nextGeneratedSequenceId++;
            return sequenceId;
        }
        return null;
    }
    /** @internal */
    _scheduleDecode() {
        if (this._dispatchDecodeScheduled || this._batchDispatchPending)
            return;
        this._dispatchDecodeScheduled = true;
        const currentPendingBatchHandle = this._currentDispatchBatchHandle;
        const dispatch = () => {
            if (this._currentDispatchBatchHandle !== currentPendingBatchHandle)
                return;
            this.dispatchPendingBatch();
        };
        const dispatchSchedule = this._batchingOptions.dispatchSchedule;
        if (this._queuedDecodeSequenceIds.size === this._totalSequences)
            dispatch();
        if (dispatchSchedule === "nextCycle") {
            if (typeof setImmediate === "function")
                setImmediate(dispatch);
            else
                setTimeout(dispatch, 0);
        }
        else if (typeof dispatchSchedule === "function")
            dispatchSchedule(dispatch);
        else {
            if (typeof setImmediate === "function")
                setImmediate(dispatch);
            else
                setTimeout(dispatch, 0);
        }
    }
    /** @internal */
    _dispatchErrorForQueuedDecodesAndDequeue(queuedDecodes, err) {
        for (const pendingDecode of queuedDecodes) {
            const [, reject] = pendingDecode.response;
            reject(err);
        }
        for (let i = 0; i < this._queuedDecodes.length; i++) {
            const item = this._queuedDecodes[i];
            if (queuedDecodes.has(item)) {
                this._queuedDecodes.splice(i, 1);
                this._queuedDecodeSequenceIds.delete(item.sequenceId);
                i--;
            }
        }
    }
    /** @internal */
    _ensureNotDisposed() {
        if (this._disposed)
            throw new DisposedError();
    }
    /** @internal */
    async _setLora({ filePath, scale }) {
        const lora = await this._model._getOrLoadLora(filePath);
        this._ctx.setLora(lora, scale ?? defaultLoraScale);
        if (!this._loraAdapters.has(lora)) {
            this._loraAdapters.add(lora);
            lora.usages++;
        }
    }
    /** @internal */
    _reserveThreads() {
        clearTimeout(this._freeReservedThreadsTimeout);
        delete this._freeReservedThreadsTimeout;
        if (this._threadSplitterConsumer != null)
            return;
        this._threadSplitterConsumer = this._llama._threadsSplitter.createConsumer(this._idealThreads, this._minThreads);
    }
    /** @internal */
    _freeReservedThreads() {
        clearTimeout(this._freeReservedThreadsTimeout);
        delete this._freeReservedThreadsTimeout;
        if (this._threadSplitterConsumer == null)
            return;
        this._threadSplitterConsumer.dispose();
        delete this._threadSplitterConsumer;
    }
    /** @internal */
    _scheduleToFreeReservedThreads() {
        if (this._threadSplitterConsumer == null)
            return;
        clearTimeout(this._freeReservedThreadsTimeout);
        this._freeReservedThreadsTimeout = setTimeout(this._freeReservedThreads, 0);
    }
    /** @internal */
    static async _create(options, { _model }) {
        const sequences = options.sequences ?? getDefaultContextSequences();
        const flashAttention = _model.flashAttentionSupported
            ? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention)
            : false;
        const swaFullCache = options.swaFullCache ?? _model.defaultContextSwaFullCache;
        const loraOptions = typeof options.lora === "string"
            ? { adapters: [{ filePath: options.lora }] }
            : options.lora;
        let failedCreationRetries = options.failedCreationRemedy === false
            ? 0
            : Math.max(0, options.failedCreationRemedy?.retries ?? defaultFailedCreationRemedy.retries);
        const failedCreationAutoContextSizeShrink = options.failedCreationRemedy === false
            ? 0
            : options.failedCreationRemedy?.autoContextSizeShrink ?? defaultFailedCreationRemedy.autoContextSizeShrink;
        let contextSize = await _model.fileInsights.configurationResolver.resolveContextContextSize(options.contextSize, {
            batchSize: options.batchSize,
            sequences: sequences,
            modelGpuLayers: _model.gpuLayers,
            modelTrainContextSize: _model.trainContextSize,
            flashAttention,
            swaFullCache,
            getVramState: () => _model._llama._vramOrchestrator.getMemoryState(),
            llamaGpu: _model._llama.gpu,
            ignoreMemorySafetyChecks: options.ignoreMemorySafetyChecks,
            isEmbeddingContext: options._embeddings
        });
        const minContextSize = options.contextSize === "auto"
            ? shrinkRetriesMinContextSize
            : (typeof options.contextSize === "object" && typeof options.contextSize.min === "number")
                ? options.contextSize.min
                : typeof options.contextSize === "number"
                    ? options.contextSize
                    : shrinkRetriesMinContextSize;
        const { createSignal } = options;
        async function createContext(contextSize) {
            const batchSize = options.batchSize ?? getDefaultContextBatchSize({ contextSize, sequences });
            const resourceRequirementsEstimation = _model.fileInsights.estimateContextResourceRequirements({
                contextSize,
                sequences,
                isEmbeddingContext: options._embeddings,
                modelGpuLayers: _model.gpuLayers,
                batchSize,
                flashAttention,
                swaFullCache
            });
            const context = new LlamaContext({ _model }, { ...options, contextSize, batchSize, sequences, flashAttention, swaFullCache });
            const contextCreationVramReservation = options.ignoreMemorySafetyChecks
                ? null
                : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.gpuVram);
            const contextCreationRamReservation = options.ignoreMemorySafetyChecks
                ? null
                : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.cpuRam);
            try {
                if (createSignal?.aborted)
                    throw createSignal.reason;
                const contextLoaded = await context._ctx.init();
                if (createSignal?.aborted) {
                    if (contextLoaded)
                        await context._ctx.dispose();
                    throw createSignal.reason;
                }
                else if (!contextLoaded)
                    throw new Error("Failed to create context");
                contextCreationVramReservation?.dispose?.();
                contextCreationRamReservation?.dispose?.();
                if (loraOptions != null && loraOptions.adapters.length > 0) {
                    let loadedAdapters = 0;
                    for (const adapter of loraOptions.adapters) {
                        try {
                            await context._setLora({
                                filePath: adapter.filePath,
                                scale: adapter.scale
                            });
                            loadedAdapters++;
                            try {
                                loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length);
                            }
                            catch (err) {
                                console.error(err);
                            }
                        }
                        catch (err) {
                            await context.dispose();
                            throw err;
                        }
                        if (createSignal?.aborted) {
                            await context.dispose();
                            throw createSignal.reason;
                        }
                    }
                }
                else if (loraOptions?.onLoadProgress != null) {
                    try {
                        loraOptions.onLoadProgress(1);
                    }
                    catch (err) {
                        console.error(err);
                    }
                }
                return context;
            }
            finally {
                contextCreationVramReservation?.dispose?.();
                contextCreationRamReservation?.dispose?.();
            }
        }
        while (failedCreationRetries >= 0) {
            try {
                return await createContext(contextSize);
            }
            catch (err) {
                if (failedCreationRetries === 0 || (createSignal?.aborted && err === createSignal.reason))
                    throw err;
                failedCreationRetries--;
                let newContextSize = typeof failedCreationAutoContextSizeShrink === "number"
                    ? Math.floor(contextSize * (1 - failedCreationAutoContextSizeShrink))
                    : Math.floor(failedCreationAutoContextSizeShrink(contextSize));
                if (!Number.isFinite(newContextSize))
                    throw err;
                if (newContextSize < minContextSize)
                    newContextSize = minContextSize;
                if (newContextSize >= contextSize)
                    throw err;
                contextSize = newContextSize;
            }
        }
        throw new Error("Failed to create context");
    }
}
export class LlamaContextSequence {
    /** @internal */ _sequenceId;
    /** @internal */ _gcRegistry;
    /** @internal */ _context;
    /** @internal */ _contextShift;
    /** @internal */ _tokenPredictor;
    /** @internal */ _tokenMeter;
    /** @internal */ _disposeAggregator = new DisposeAggregator();
    /** @internal */ _lock = {};
    /** @internal */ _resetTokenPredictor = false;
    /** @internal */ _tokenPredictorOwner = {};
    /** @internal */ _contextTokens = [];
    /** @internal */ _nextTokenIndex = 0;
    /** @internal */ _loadedTokenPredictions = [];
    /** @internal */ _usedTokenPredictions = 0;
    /** @internal */ _unusedTokenPredictions = 0;
    /** @internal */ _validatedTokenPredictions = 0;
    /** @internal */ _refutedTokenPredictions = 0;
    /** @internal */ _disposed = false;
    onDispose = new EventRelay();
    constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor }) {
        this._sequenceId = sequenceId;
        this._context = context;
        this._tokenMeter = tokenMeter ?? new TokenMeter();
        this._contextShift = contextShift;
        this._tokenPredictor = tokenPredictor;
        this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId);
        this._gcRegistry.register(this, sequenceId);
        this._disposeAggregator.add(() => this._gcRegistry.unregister(this));
        this._disposeAggregator.add(this.onDispose.dispatchEvent);
        this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextSequenceIfReferenced.bind(null, new WeakRef(this))));
        this._disposeAggregator.add(() => {
            this._context._reclaimUnusedSequenceId(this._sequenceId);
        });
        if (this._tokenPredictor != null)
            this._disposeAggregator.add(this._tokenPredictor);
    }
    dispose() {
        if (this._disposed)
            return;
        this._disposeAggregator.dispose();
        this._contextTokens.length = 0;
        this._disposed = true;
    }
    /** @hidden */
    [Symbol.dispose]() {
        return this.dispose();
    }
    get disposed() {
        return this._disposed;
    }
    get context() {
        return this._context;
    }
    get model() {
        return this._context.model;
    }
    /** The maximum number of tokens that the sequence state can hold */
    get contextSize() {
        return this._context.contextSize;
    }
    /** The index where the next evaluated token will be placed in the context */
    get nextTokenIndex() {
        return this._nextTokenIndex - this._loadedTokenPredictions.length;
    }
    /** The current context state tokens */
    get contextTokens() {
        if (this._loadedTokenPredictions.length === 0)
            return this._contextTokens.slice();
        return this._contextTokens.slice(0, -this._loadedTokenPredictions.length);
    }
    get tokenMeter() {
        return this._tokenMeter;
    }
    /**
     * The token predictor used when creating this sequence.
     */
    get tokenPredictor() {
        return this._tokenPredictor;
    }
    /**
     * Get the index of the first token in the KV cache.
     *
     * If you remove any tokens from the state that come before this index,
     * no cached prefix tokens evaluation state will be used for the next evaluation.
     *
     * For example, if `stateCellsStartIndex` is `10` and you remove the range `{start: 11, end: 16}`
     * then the cached state for range `0-10` will be used in the next evaluation,
     * but if you remove the range `{start: 10, end: 16}` (or `{start: 9, end: 16}`) then the cached state will not be used at all
     * and will be re-evaluated in the next evaluation.
     *
     * This index can be greater than `0` only when SWA (Sliding Window Attention) is used (only on supported models).
     *
     * When SWA is used, this index will usually be `Math.max(-1, .nextTokenIndex - .model.fileInsights.swaSize)` or larger.
     *
     * When the KV cache is empty, this index will be `-1`.
     *
     * You can disable SWA by setting the `swaFullCache` option to `true` when creating a context.
     */
    get stateCellsStartIndex() {
        this._ensureNotDisposed();
        return this._context._ctx.getSequenceKvCacheMinPosition(this._sequenceId);
    }
    /**
     * Statistics of token predictions using the sequence's `tokenPredictor`.
     *
     * The statistics change only when token prediction is used in this sequence.
     *
     * `validated` + `refuted` = total number of evaluated predictions.
     *
     * Prefer using `validated` and `refuted` to evaluate the effectiveness of token prediction.
     */
    get tokenPredictions() {
        return {
            used: this._usedTokenPredictions,
            unused: this._unusedTokenPredictions,
            validated: this._validatedTokenPredictions,
            refuted: this._refutedTokenPredictions
        };
    }
    get isLoadedToMemory() {
        return !this._disposed;
    }
    compareContextTokens(tokens) {
        for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length; i++) {
            if (compareTokens(this._contextTokens[i], tokens[i]))
                continue;
            return {
                firstDifferentIndex: i
            };
        }
        return {
            firstDifferentIndex: this._contextTokens.length - this._loadedTokenPredictions.length
        };
    }
    /**
     * Erase parts of the context state to align it with the given tokens.
     *
     * If the given tokens do not align with the current context state, the context state will be erased to align with the given tokens.
     *
     * To find the first different token index between the context state and the given tokens, access the `nextTokenIndex` property.
     *
     * If `allowShift` is `true` (the default), shifting tokens may happen to align the context state with the given tokens,
     * which incurs token evaluation of the shifted tokens.
     */
    async adaptStateToTokens(tokens, allowShift = true) {
        const modelSupportsShifting = !this.model.fileInsights.isRecurrent &&
            this.model.fileInfo.metadata?.general?.architecture !== GgufArchitectureType.deepseek2;
        if (!modelSupportsShifting || !allowShift) {
            const { firstDifferentIndex } = this.compareContextTokens(tokens);
            if (firstDifferentIndex < this.nextTokenIndex)
                await this._eraseContextTokenRanges([{
                        start: firstDifferentIndex,
                        end: this._nextTokenIndex
                    }]);
            return;
        }
        const eraseRanges = [];
        let tokensIndex = 0;
        let differentTokenIndex = undefined;
        for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length && tokensIndex < tokens.length; i++) {
            if (compareTokens(this._contextTokens[i], tokens[tokensIndex])) {
                if (differentTokenIndex != null) {
                    eraseRanges.push({
                        start: differentTokenIndex,
                        end: i
                    });
                    differentTokenIndex = undefined;
                }
                tokensIndex++;
                continue;
            }
            if (differentTokenIndex == null)
                differentTokenIndex = i;
        }
        if (differentTokenIndex != null)
            eraseRanges.push({
                start: differentTokenIndex,
                end: this._nextTokenIndex
            });
        if (eraseRanges.length > 0)
            await this._eraseContextTokenRanges(eraseRanges);
    }
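    // Illustrative usage sketch, assuming `sequence` holds an old prompt and `newTokens`
    // shares a prefix with it; only the differing parts of the state are erased:
    //     await sequence.adaptStateToTokens(newTokens);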
    /**
     * Clear the history of the sequence.
     */
    async clearHistory() {
        this._ensureNotDisposed();
        await this._eraseContextTokenRanges([{ start: 0, end: this._nextTokenIndex }]);
    }
    /**
     * Erase context tokens in the provided ranges to free up space for new tokens to be generated.
     * The start of each range is inclusive, and the end of each range is exclusive.
     * For example, the range `{start: 0, end: 1}` will remove the token at the `0` index only.
     */
    eraseContextTokenRanges(ranges) {
        return this._eraseContextTokenRanges(ranges);
    }
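    // Illustrative usage sketch: erase the first 10 tokens of the sequence state:
    //     await sequence.eraseContextTokenRanges([{start: 0, end: 10}]);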
    /** @internal */
    async _eraseContextTokenRanges(ranges, { canResetTokenPredictor = true, canRemovePredictionTokens = true, skipLock = false } = {}) {
        this._ensureNotDisposed();
        let awaitPromise;
        await withLock([this._context, "context"], async () => {
            this._ensureNotDisposed();
            if (ranges.length === 0)
                return;
            // if the deletion fails, we'll have to dispose the sequence and fill it up again
            let deletionSuccessful = true;
            const resolvedRanges = ranges
                .map(({ start, end }) => {
                    if (start === end)
                        return null;
                    if (start > end)
                        [start, end] = [end, start];
                    if (end > this._nextTokenIndex)
                        end = this._nextTokenIndex;
                    if (start >= this._nextTokenIndex)
                        return null;
                    return { start, end };
                })
                .filter((range) => range != null)
                .sort((a, b) => a.start - b.start)
                .reduce((ranges, range) => {
                    if (ranges.length === 0)
                        return [range];
                    const lastRange = ranges[ranges.length - 1];
                    if (lastRange.end >= range.start) {
                        lastRange.end = Math.max(lastRange.end, range.end);
                        return ranges;
                    }
                    ranges.push(range);
                    return ranges;
                }, []);
            const minKvCachePosition = (this._contextTokens.length === 0 && this._loadedTokenPredictions.length === 0)
                ? 0
                : Math.max(0, this._context._ctx.getSequenceKvCacheMinPosition(this._sequenceId));
            if (resolvedRanges[0] != null && resolvedRanges[0].start <= minKvCachePosition)
                // we have to drop the cache and reevaluate the sequence due to missing KV cache
                deletionSuccessful = false;
            const tokenPredictionsToRemove = (resolvedRanges.length > 0 && canRemovePredictionTokens)
                ? this._loadedTokenPredictions.length
                : 0;
            if (tokenPredictionsToRemove > 0) {
                const startDeleteIndex = this._nextTokenIndex - this._loadedTokenPredictions.length;
                const lastDeleteRange = resolvedRanges[resolvedRanges.length - 1];
                if (lastDeleteRange.end >= startDeleteIndex)
                    lastDeleteRange.end = this._nextTokenIndex;
                else
                    resolvedRanges.push({ start: startDeleteIndex, end: this._nextTokenIndex });
                if (canResetTokenPredictor)
                    await this._abortTokenPredictor(true);
            }
            let removedTokens = 0;
            let lastDeleteRangeEndPos = null;
            for (const range of resolvedRanges) {
                this._contextTokens.splice(range.start - removedTokens, range.end - range.start);
                if (deletionSuccessful)
                    deletionSuccessful &&= this._context._ctx.removeTokenCellsFromSequence(this._sequenceId, range.start, range.end);
                if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start) {
                    this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, range.start, -removedTokens);
                    const shiftedTokens = range.start - lastDeleteRangeEndPos;
                    this._tokenMeter.useTokens(shiftedTokens, "input");
                }
                removedTokens += range.end - range.start;
                lastDeleteRangeEndPos = range.end;
            }
            if (tokenPredictionsToRemove > 0)
                this._loadedTokenPredictions.splice(0, tokenPredictionsToRemove);
            if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 &&
                lastDeleteRangeEndPos !== this._nextTokenIndex) {
                this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, this._nextTokenIndex, -removedTokens);
                const shiftedTokens = this._nextTokenIndex - lastDeleteRangeEndPos;
                this._tokenMeter.useTokens(shiftedTokens, "input");
            }
            this._nextTokenIndex -= removedTokens;
            if (canResetTokenPredictor && removedTokens > 0)
                await this._abortTokenPredictor(true);
            if (deletionSuccessful)
                return;
            const newSequenceTokens = this._contextTokens.slice();
            this._nextTokenIndex = 0;
            this._context._ctx.disposeSequence(this._sequenceId);
            // wait for the evaluation outside the "context" lock to avoid deadlocks
            awaitPromise = this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock });
        });
        if (awaitPromise != null)
            await awaitPromise;
    }
    /**
     * Evaluate the provided tokens into the context sequence, and continue generating new tokens on iterator iterations.
     *
     * This method uses the token predictor (when provided) to generate new tokens faster.
     */
    async *evaluate(tokens, options = {}) {
        const iterator = this.evaluateWithMetadata(tokens, {}, options);
        let iterateInput = undefined;
        try {
            while (true) {
                const { value, done } = await iterator.next(iterateInput);
                if (done)
                    return;
                iterateInput = yield value.token;
            }
        }
        finally {
            await iterator.return();
        }
    }
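    // Illustrative usage sketch, assuming `sequence` is a `LlamaContextSequence`
    // and `model` is its model:
    //     const tokens = model.tokenize("The quick brown fox");
    //     let generated = 0;
    //     for await (const token of sequence.evaluate(tokens, {temperature: 0.8})) {
    //         process.stdout.write(model.detokenize([token]));
    //         if (++generated >= 32)
    //             break;
    //     }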
    /**
     * Like {@link evaluate `.evaluate(...)`}, but with additional metadata for each generated token.
     *
     * Configure the additional metadata options to choose which metadata to include.
     */
    evaluateWithMetadata(tokens, metadata, options = {}) {
        const { temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, yieldEogToken = false, _noSampling = false } = options;
        if (this._tokenPredictor != null && !_noSampling && tokens.length > 0)
            return this._speculativeEvaluate(tokens, metadata, {
                temperature,
                minP,
                topK,
                topP,
                seed,
                grammarEvaluationState,
                repeatPenalty,
                tokenBias,
                evaluationPriority,
                contextShiftOptions: {
                    size: contextShiftSize,
                    strategy: contextShiftStrategy
                },
                yieldEogToken,
                tokenPredictor: this._tokenPredictor
            });
        return this._evaluate(tokens, metadata, {
            temperature,
            minP,
            topK,
            topP,
            seed,
            grammarEvaluationState,
            repeatPenalty,
            tokenBias,
            evaluationPriority,
            contextShiftOptions: {
                size: contextShiftSize,
                strategy: contextShiftStrategy
            },
            yieldEogToken,
            _noSampling
        });
    }
    /**
     * Evaluate the provided tokens into the context sequence without generating new tokens.
     */
    async evaluateWithoutGeneratingNewTokens(tokens, options = {}) {
        const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, _skipLock = false } = options;
        const iterator = this._evaluate(tokens, {}, {
            generateNewTokens: false,
            evaluationPriority,
            contextShiftOptions: {
                size: contextShiftSize,
                strategy: contextShiftStrategy
            },
            _skipLock
        });
        const predictorAlignmentPromise = this.tokenPredictor == null
            ? undefined
            : this._tokenPredictor?.reset({
                stateTokens: [...this._contextTokens, ...tokens],
                evaluateOptions: {
                    evaluationPriority,
                    contextShift: {
                        size: contextShiftSize,
                        strategy: contextShiftStrategy
                    }
                },
                targetSequence: this
            });
        if (predictorAlignmentPromise != null) {
            this._tokenPredictorOwner = {};
            this._resetTokenPredictor = false;
        }
        // eslint-disable-next-line @typescript-eslint/no-unused-vars
        for await (const token of iterator) {
            // Array.from doesn't work with async generators, so we have to iterate over the generator
        }
        await iterator.return();
        if (predictorAlignmentPromise != null)
            await predictorAlignmentPromise;
    }
    /**
     * Evaluate the provided tokens into the context sequence with custom options for each token.
     *
     * This method allows for more precise control of the generation process.
     *
     * A next token will be generated for a given token only if any of the `generateNext` options for it are used.
     *
     * To generate more tokens after this method finishes,
     * use it again with token(s) you selected to add to the context from the previous evaluation.
     *
     * This method doesn't use the token predictor (when provided) since it cannot predict which tokens are actually needed.
     * Use the `evaluate` method when you need to use token prediction.
     * @returns An array where for each token in the input array, there can be an output item at the same index in the output array.
     * For indexes that have no output, there won't be any value at the corresponding index in the output array.
     *
     * It's recommended to iterate from `0` up to the length of the input array to check the results in the output array.
     */
    async controlledEvaluate(input, options) {
        const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {} } = options ?? {};
        const contextShiftOptions = {
            size: contextShiftSize,
            strategy: contextShiftStrategy
        };
        this._ensureNotDisposed();
        if (input.length === 0)
            return [];
        await this._abortTokenPredictor();
        const sampler = new LlamaSampler(this.model);
        const onTokenResult = safeEventCallback(options?.onTokenResult);
        const logitsArray = [];
        const resolvedTokens = input.map((item, index) => {
            if (item instanceof Array) {
                const [token, options] = item;
                const generateNext = options?.generateNext ?? {};
                if (generateNext.probabilities === true || generateNext.confidence === true || generateNext.token === true)
                    logitsArray[index] = true;
                return token;
            }
            return item;
        });
        const evaluatorLock = await acquireLock([this._lock, "evaluate"]);
        try {
            return await this._decodeTokens(resolvedTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, async (batchLogitIndex, tokenIndex) => {
                const inputToken = input[tokenIndex];
                const inputOptions = inputToken instanceof Array
                    ? (inputToken[1] ?? {})
                    : {};
                const generateNext = inputOptions.generateNext;
                if (generateNext == null || ((generateNext.probabilities == null || !generateNext.probabilities) &&
                    (generateNext.token == null || !generateNext.token) &&
                    (generateNext.confidence == null || !generateNext.confidence)))
                    return undefined;
                const sampleOptions = generateNext.options ?? {};
                const samplerConfig = this._resolveSamplerConfig({
                    temperature: sampleOptions.temperature,
                    minP: sampleOptions.minP,
                    topK: sampleOptions.topK,
                    topP: sampleOptions.topP,
                    seed: sampleOptions.seed,
                    repeatPenalty: sampleOptions.repeatPenalty,
                    tokenBias: sampleOptions.tokenBias
                });
                return await withLock([sampler, "sample"], async () => {
                    if (sampler.disposed)
                        return undefined;
                    sampler.applyConfig(samplerConfig);
                    const [token, probabilities, confidence] = await this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, !!generateNext.probabilities, !!generateNext.confidence);
                    const output = {
                        next: {}
                    };
                    if (generateNext.token)
                        output.next.token = token === -1
                            ? null
                            : (token ?? null);
                    if (confidence != null)
                        output.next.confidence = confidence;
                    if (probabilities != null)
                        output.next.probabilities = reviveTokenProbabilities(probabilities);
                    onTokenResult?.(tokenIndex, output);
                    return output;
                });
            });
        }
        finally {
            evaluatorLock.dispose();
            void withLock([sampler, "sample"], sampler.asyncDispose);
        }
    }
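    // Illustrative usage sketch, assuming `sequence` and a `tokens` array exist;
    // only the last input token requests a generated next token and its confidence:
    //     const results = await sequence.controlledEvaluate([
    //         ...tokens.slice(0, -1),
    //         [tokens.at(-1), {generateNext: {token: true, confidence: true}}]
    //     ]);
    //     const last = results[tokens.length - 1];
    //     console.log(last?.next?.token, last?.next?.confidence);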
    /* eslint-disable @stylistic/max-len */
    /**
     * Save the current context sequence evaluation state to a file.
     * @see [Saving and restoring a context sequence evaluation state](https://node-llama-cpp.withcat.ai/guide/chat-session#save-and-restore-with-context-sequence-state)
     */
    async saveStateToFile(filePath) {
        /* eslint-enable @stylistic/max-len */
        this._ensureNotDisposed();
        const resolvedPath = path.resolve(process.cwd(), filePath);
        const evaluatorLock = await acquireLock([this._lock, "evaluate"]);
        const contextLock = await acquireLock([this._context, "context"]);
        try {
            this._ensureNotDisposed();
            const fileSize = await this._context._ctx.saveSequenceStateToFile(resolvedPath, this._sequenceId, Uint32Array.from(this.contextTokens));
            return { fileSize };
        }
        finally {
            contextLock.dispose();
            evaluatorLock.dispose();
        }
    }
    /* eslint-disable @stylistic/max-len */
    /**
     * Load a context sequence evaluation state from a file.
     *
     * Trying to load a state file with a longer context size than the current sequence's context size will fail and throw an error.
     *
     * You must ensure that the file was created from the exact same model, otherwise, using this function may crash the process.
     * @see [Saving and restoring a context sequence evaluation state](https://node-llama-cpp.withcat.ai/guide/chat-session#save-and-restore-with-context-sequence-state)
     */
    async loadStateFromFile(filePath, acceptRisk) {
        /* eslint-enable @stylistic/max-len */
        if (!acceptRisk.acceptRisk)
            throw new Error("The `acceptRisk` option must be set to `true` to use this feature");
        this._ensureNotDisposed();
        const resolvedPath = path.resolve(process.cwd(), filePath);
        const evaluatorLock = await acquireLock([this._lock, "evaluate"]);
        const contextLock = await acquireLock([this._context, "context"]);
        try {
            this._ensureNotDisposed();
            this._tokenPredictorOwner = {};
            await this._abortTokenPredictor(true);
            this._ensureNotDisposed();
            this._loadedTokenPredictions.length = 0;
            this._nextTokenIndex = 0;
            this._contextTokens = [];
            const tokens = Array.from(await this._context._ctx.loadSequenceStateFromFile(resolvedPath, this._sequenceId, this.contextSize));
            if (tokens.length > this.contextSize) {
                this._context._ctx.disposeSequence(this._sequenceId);
                throw new Error("The given state file is too large for the current context size");
            }
            this._contextTokens = tokens;
            this._nextTokenIndex = tokens.length;
            this._loadedTokenPredictions.length = 0;
        }
        finally {
            contextLock.dispose();
            evaluatorLock.dispose();
        }
    }
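    // Illustrative usage sketch (the state file path is arbitrary):
    //     await sequence.saveStateToFile("state.bin");
    //     await sequence.loadStateFromFile("state.bin", {acceptRisk: true});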
    /** @internal */
    async *_evaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, generateNewTokens = true, contextShiftOptions, yieldEogToken = false, _noSampling = false, _skipLock = false }) {
        this._ensureNotDisposed();
        let evalTokens = tokens;
        if (evalTokens.length === 0)
            return;
        await this._abortTokenPredictor(false, true);
        const sampleProbabilities = metadata.probabilities === true;
        const sampleConfidence = metadata.confidence === true;
        const sampler = new LlamaSampler(this.model);
        try {
            while (true) {
                this._ensureNotDisposed();
                const evaluatorLock = _skipLock
                    ? undefined
                    : await acquireLock([this._lock, "evaluate"]);
                let nextToken;
                const yieldRes = {};
                try {
                    const logitsArray = [];
                    if (generateNewTokens)
                        logitsArray[evalTokens.length - 1] = true;
                    // Evaluate to get the next token.
                    const decodeResult = await this._decodeTokens(evalTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, (batchLogitIndex) => {
                        if (_noSampling)
                            return null;
                        const samplerConfig = this._resolveSamplerConfig({
                            temperature,
                            minP,
                            topK,
                            topP,
                            seed,
                            grammarEvaluationState,
                            repeatPenalty,
                            tokenBias
                        });
                        return withLock([sampler, "sample"], async () => {
                            if (sampler.disposed)
                                return null;
                            sampler.applyConfig(samplerConfig);
                            if (sampleProbabilities || sampleConfidence)
                                return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, sampleProbabilities, sampleConfidence);
                            else
                                return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler);
                        });
                    });
                    const lastDecodeResult = decodeResult[evalTokens.length - 1];
                    if (lastDecodeResult instanceof Array) {
                        const [token, probabilities, confidence] = lastDecodeResult;
                        nextToken = token;
                        if (probabilities != null)
                            yieldRes.probabilities = reviveTokenProbabilities(probabilities);
                        if (confidence != null)
                            yieldRes.confidence = confidence;
                    }
                    else
                        nextToken = lastDecodeResult;
                    if (nextToken === -1)
                        throw new Error("Failed to sample next token");
                    if (nextToken == null)
                        return;
                    // the model finished generating text
                    if (!yieldEogToken && this._context.model.isEogToken(nextToken))
                        break;
                }
                finally {
                    evaluatorLock?.dispose();
                }
                yieldRes.token = nextToken;
                const replacementToken = yield yieldRes;
                // set the tokens for the next evaluation
                if (replacementToken instanceof Array)
                    evalTokens = replacementToken.slice();
                else if (replacementToken != null)
                    evalTokens = [replacementToken];
                else
                    evalTokens = [nextToken];
            }
        }
        finally {
            void withLock([sampler, "sample"], sampler.asyncDispose);
        }
    }
    /** @internal */
    async *_speculativeEvaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShiftOptions, yieldEogToken = false, tokenPredictor }) {
        this._ensureNotDisposed();
        let evalTokens = tokens.slice();
        if (evalTokens.length === 0)
            return;
        const tokenPredictorOwner = {};
        this._tokenPredictorOwner = tokenPredictorOwner;
        await this._abortTokenPredictor();
        const sampleProbabilities = metadata.probabilities === true;
        const sampleConfidence = metadata.confidence === true;
        let logitsArray = [];
        let logitsStartIndex = evalTokens.length - 1;
        const validatedTokens = [];
        logitsArray[logitsStartIndex] = true;
        const sampler = new LlamaSampler(this.model);
        try {
            while (true) {
                this._ensureNotDisposed();
                const evaluatorLock = await acquireLock([this._lock, "evaluate"]);
                let nextToken;
                const yieldRes = {};
                try {
                    if (this._tokenPredictorOwner === tokenPredictorOwner &&
                        this._loadedTokenPredictions.length > 0 &&
                        evalTokens.length === 1 &&
                        evalTokens[0] === this._loadedTokenPredictions[0]?.[0]) {
                        const [token, probabilities, confidence] = this._loadedTokenPredictions.shift()[1];
                        nextToken = token;
                        yieldRes.token = nextToken;
                        if (probabilities != null)
                            yieldRes.probabilities = reviveTokenProbabilities(probabilities);
                        if (confidence != null)
                            yieldRes.confidence = confidence;
                        const resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function
                            ? grammarEvaluationState()
                            : grammarEvaluationState;
                        if (resolvedGrammarEvaluationState != null)
                            LlamaSampler._acceptTokenOnGrammarEvaluationState(this._context._llama, resolvedGrammarEvaluationState, nextToken);
                        this._unusedTokenPredictions--;
                        this._usedTokenPredictions++;
                    }
                    else if (this._tokenPredictorOwner === tokenPredictorOwner && this._loadedTokenPredictions.length > 0) {
                        const deleteStartIndex = Math.max(0, this._nextTokenIndex - this._loadedTokenPredictions.length);
                        await this._eraseContextTokenRanges([{ start: deleteStartIndex, end: this._nextTokenIndex }], { canResetTokenPredictor: true, canRemovePredictionTokens: true, skipLock: true });
                        this._loadedTokenPredictions.length = 0;
                    }
                    if (this._resetTokenPredictor) {
                        await tokenPredictor.reset({
                            stateTokens: [...this._contextTokens, ...evalTokens],
                            evaluateOptions: {
                                temperature,
                                minP,
                                topK,
                                topP,
                                seed,
                                grammarEvaluationState: grammarEvaluationState instanceof Function
                                    ? grammarEvaluationState()?.clone()
                                    : grammarEvaluationState?.clone(),
                                repeatPenalty,
                                tokenBias,
                                evaluationPriority,
                                contextShift: contextShiftOptions,
                                yieldEogToken: true
                            },
                            targetSequence: this
                        });
                        this._resetTokenPredictor = false;
                        this._tokenPredictorOwner = tokenPredictorOwner;
                    }
|
|
if (nextToken == null) {
|
|
if (this._tokenPredictorOwner === tokenPredictorOwner &&
|
|
// prevent incurring context shifts due to token prediction validations
|
|
this._nextTokenIndex + evalTokens.length < this._context.contextSize) {
|
|
const testGrammarClone = grammarEvaluationState instanceof Function
|
|
? grammarEvaluationState()?.clone()
|
|
: grammarEvaluationState?.clone();
|
|
for (const token of await tokenPredictor.predictTokens()) {
|
|
if (testGrammarClone != null) {
|
|
const canAddToken = LlamaSampler._canBeNextTokenForGrammarEvaluationState(this.model._llama, testGrammarClone, token);
|
|
if (!canAddToken)
|
|
break;
|
|
}
|
|
evalTokens.push(token);
|
|
logitsArray[evalTokens.length - 1] = true;
|
|
// prevent incurring context shifts due to token prediction validations
|
|
if (this._nextTokenIndex + evalTokens.length >= this._context.contextSize)
|
|
break;
|
|
}
|
|
}
|
|
let resolvedGrammarEvaluationState = undefined;
|
|
// Evaluate to get the next token.
|
|
const decodeResult = await this._decodeTokens(evalTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, (batchLogitIndex, tokenIndex) => {
|
|
if (tokenIndex === logitsStartIndex)
|
|
resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function
|
|
? grammarEvaluationState()
|
|
: grammarEvaluationState;
|
|
else if (tokenIndex === logitsStartIndex + 1)
|
|
resolvedGrammarEvaluationState = resolvedGrammarEvaluationState?.clone();
|
|
const samplerConfig = this._resolveSamplerConfig({
|
|
temperature,
|
|
minP,
|
|
topK,
|
|
topP,
|
|
seed,
|
|
grammarEvaluationState: resolvedGrammarEvaluationState,
|
|
repeatPenalty,
|
|
tokenBias
|
|
});
|
|
return withLock([sampler, "sample"], async () => {
|
|
if (sampler.disposed)
|
|
return null;
|
|
sampler.applyConfig(samplerConfig);
|
|
if (sampleProbabilities || sampleConfidence)
|
|
return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, sampleProbabilities, sampleConfidence);
|
|
else
|
|
return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler);
|
|
});
|
|
});
|
|
                        for (let i = logitsStartIndex; i < evalTokens.length; i++) {
                            const item = decodeResult[i];
                            const [resultToken, probabilities, confidence] = item instanceof Array
                                ? item
                                : [item];
                            if (i === logitsStartIndex) {
                                if (resultToken === -1)
                                    throw new Error("Failed to sample next token");
                                if (resultToken == null)
                                    return;
                                nextToken = resultToken;
                                yieldRes.token = nextToken;
                                if (probabilities != null)
                                    yieldRes.probabilities = reviveTokenProbabilities(probabilities);
                                if (confidence != null)
                                    yieldRes.confidence = confidence;
                            }
                            else {
                                if (resultToken === -1 || resultToken == null)
                                    break;
                                // a prediction is validated when the model's output for the previous position matches the predicted token
                                const lastValidatedTokenOutput = i === logitsStartIndex + 1
                                    ? nextToken
                                    : validatedTokens.at(-1)?.[1];
                                if (lastValidatedTokenOutput != null && lastValidatedTokenOutput === evalTokens[i]) {
                                    this._loadedTokenPredictions.push([evalTokens[i], [resultToken, probabilities, confidence]]);
                                    this._validatedTokenPredictions++;
                                    this._unusedTokenPredictions++;
                                }
                                else {
                                    const deleteSize = Math.min(evalTokens.length - i, this.context.contextSize);
                                    this._refutedTokenPredictions += deleteSize;
                                    const deleteStartIndex = this._nextTokenIndex - deleteSize;
                                    tokenPredictor.stop(true);
                                    await this._eraseContextTokenRanges([{
                                        start: deleteStartIndex,
                                        end: this._nextTokenIndex
                                    }], { canResetTokenPredictor: false, canRemovePredictionTokens: false, skipLock: true });
                                    break; // the assumption that this token will be generated was wrong
                                }
                            }
                        }
                    }
                    if (nextToken == null)
                        throw new Error("Failed to generate next token");
                    // the model finished generating text
                    if (!yieldEogToken && this._context.model.isEogToken(nextToken))
                        break;
                }
                finally {
                    evaluatorLock.dispose();
                }
                const replacementToken = yield yieldRes;
                // set the tokens for the next evaluation;
                // the caller can inject replacement token(s) through the generator's .next(...) call
                if (replacementToken instanceof Array)
                    evalTokens = replacementToken.slice();
                else if (replacementToken != null)
                    evalTokens = [replacementToken];
                else
                    evalTokens = [nextToken];
                if (this._tokenPredictorOwner === tokenPredictorOwner)
                    tokenPredictor.pushTokens(evalTokens);
                logitsArray = [];
                logitsStartIndex = evalTokens.length - 1;
                logitsArray[logitsStartIndex] = true;
            }
        }
        finally {
            void withLock([sampler, "sample"], sampler.asyncDispose);
            if (this._tokenPredictorOwner === tokenPredictorOwner)
                tokenPredictor.stop();
        }
    }
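    /*
     * A minimal sketch of the replacement-token protocol implemented by the generator above
     * (the caller-side code is illustrative, and `replacementToken` is hypothetical):
     * whatever value is passed to the iterator's .next() call replaces the token(s) that
     * would otherwise be evaluated on the next iteration, while passing nothing continues
     * with the token the model just sampled.
     *
     *     let result = await iterator.next();
     *     while (!result.done) {
     *         const {token} = result.value;
     *         // pass a token, or an array of tokens, to steer the next evaluation
     *         result = await iterator.next(replacementToken);
     *     }
     */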
    /** @internal */
    async _abortTokenPredictor(skipClearingPredictionsFromState = false, skipLock = false) {
        this._tokenPredictor?.stop();
        this._resetTokenPredictor = true;
        if (skipClearingPredictionsFromState)
            return;
        // drop any not-yet-validated predicted tokens from the context state
        if (this._loadedTokenPredictions.length > 0)
            await this._eraseContextTokenRanges([{
                start: this._nextTokenIndex - this._loadedTokenPredictions.length,
                end: this._nextTokenIndex
            }], { canResetTokenPredictor: true, canRemovePredictionTokens: true, skipLock });
    }
    /** @internal */
    _resolveSamplerConfig({ temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias }) {
        const repeatPenaltyTokens = repeatPenalty?.punishTokens instanceof Function
            ? repeatPenalty.punishTokens()
            : repeatPenalty?.punishTokens;
        const maxPunishTokens = Math.max(repeatPenalty?.maxPunishTokens ?? defaultMaxPunishTokens, repeatPenaltyTokens?.length ?? 0);
        const resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function
            ? grammarEvaluationState()
            : grammarEvaluationState;
        if (resolvedGrammarEvaluationState != null && resolvedGrammarEvaluationState._llama !== this.model._llama)
            throw new Error("The LlamaGrammar passed to this function was created with a different Llama instance than the one used by this sequence's model. Make sure you use the same Llama instance for both the model and the grammar.");
        const { tokenBiasKeys, tokenBiasValues } = getTokenBiasesForAddon(tokenBias, this.model);
        return removeNullFields({
            temperature,
            minP,
            topK,
            topP,
            // when no finite seed is provided, fall back to the current Unix time
            seed: Math.max(0, Number.isFinite(seed)
                ? Math.floor(seed ?? (Date.now() / 1000))
                : Math.floor(Date.now() / 1000)),
            repeatPenalty: repeatPenalty?.penalty,
            repeatPenaltyMaxTokens: maxPunishTokens,
            repeatPenaltyTokens: repeatPenaltyTokens != null
                ? Uint32Array.from(repeatPenaltyTokens)
                : undefined,
            repeatPenaltyPresencePenalty: repeatPenalty?.presencePenalty,
            repeatPenaltyFrequencyPenalty: repeatPenalty?.frequencyPenalty,
            tokenBiasKeys,
            tokenBiasValues,
            grammarEvaluationState: resolvedGrammarEvaluationState?._state
        });
    }
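    /*
     * For illustration (the values shown are hypothetical): with the defaults above and no
     * repeat penalty, grammar, or token bias configured, the resolved config handed to the
     * native sampler reduces to
     *     { temperature: 0, minP: 0, topK: 40, topP: 0.95, seed: <current Unix time> }
     * since removeNullFields() strips every field that resolved to undefined.
     */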
    /**
     * The caller of this function has to wrap it with a lock to ensure this function doesn't run concurrently.
     * @internal
     */
    async _decodeTokens(tokens, logits, evaluationPriority, tokenMeter, contextShiftOptions, logitDataMapper) {
        this._ensureNotDisposed();
        const tokensLeftToDecode = tokens.slice();
        const tokenLogitsLeftToDecode = logits.slice();
        let currentTokenIndex = 0;
        const res = [];
        // translate context-state token indexes back to indexes in the `tokens` argument
        const normalizedLogitDataMapper = (batchLogitIndex, contextStateTokenIndex) => {
            return logitDataMapper(batchLogitIndex, currentTokenIndex + (contextStateTokenIndex - this._nextTokenIndex));
        };
        while (tokensLeftToDecode.length > 0) {
            this._ensureNotDisposed();
            // keep one slot free and decode in chunks that fit the remaining context space
            let freeSpace = this._context.contextSize - 1 - this._nextTokenIndex;
            if (freeSpace <= 0) {
                await this._freeUpSpaceForTokens(contextShiftOptions);
                freeSpace = this._context.contextSize - 1 - this._nextTokenIndex;
                if (freeSpace <= 0)
                    throw new Error("Failed to free up space for new tokens");
            }
            const tokensToDecode = tokensLeftToDecode.splice(0, freeSpace);
            // consume the logit flags in lockstep with the tokens (`slice` here would reuse the first chunk's flags)
            const tokensLogits = tokenLogitsLeftToDecode.splice(0, tokensToDecode.length);
            const generatedLogits = await this._context._decodeTokens({
                sequenceId: this._sequenceId,
                tokens: tokensToDecode,
                firstTokenSequenceIndex: this._nextTokenIndex,
                logits: tokensLogits,
                evaluationPriority,
                tokenMeter
            }, normalizedLogitDataMapper);
            for (const [index, value] of generatedLogits)
                res[currentTokenIndex + (index - this._nextTokenIndex)] = value;
            this._nextTokenIndex += tokensToDecode.length;
            currentTokenIndex += tokensToDecode.length;
            this._contextTokens = this._contextTokens.concat(tokensToDecode);
        }
        return res;
    }
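    /*
     * For example (the numbers are illustrative): decoding 6000 tokens into a sequence with a
     * 4096-token context happens in chunks - the first chunk fills the free space (one slot is
     * always kept unoccupied), then _freeUpSpaceForTokens() performs a context shift, and the
     * remaining tokens are decoded into the reclaimed space, repeating until none are left.
     */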
    /** @internal */
    async _freeUpSpaceForTokens(contextShiftOptions) {
        this._ensureNotDisposed();
        const size = Math.min(this._nextTokenIndex, Math.max(1, contextShiftOptions.size instanceof Function
            ? await contextShiftOptions.size(this)
            : contextShiftOptions.size));
        this._ensureNotDisposed();
        if (contextShiftOptions.strategy === "eraseBeginning") {
            let eraseStartIndex = 0;
            // keep the BOS token in place when erasing from the beginning
            if (this.model.tokens.bos != null && this._contextTokens[0] === this.model.tokens.bos)
                eraseStartIndex = 1;
            await this._eraseContextTokenRanges([{ start: eraseStartIndex, end: size + eraseStartIndex }], { skipLock: true });
        }
        else {
            const ranges = await contextShiftOptions.strategy({
                sequence: this,
                size
            });
            if (ranges == null)
                throw new Error("Invalid delete ranges");
            await this._eraseContextTokenRanges(ranges, { skipLock: true });
            // fall back to erasing from the beginning if the custom strategy didn't free enough space
            if (this._nextTokenIndex >= this._context.contextSize - 1)
                await this._eraseContextTokenRanges([{ start: 0, end: size }], { skipLock: true });
        }
    }
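    /*
     * A minimal sketch (with hypothetical values) of a custom context shift strategy matching
     * the contract above: the strategy receives the sequence and the number of tokens to free,
     * and returns the token ranges to erase.
     *
     *     const contextShift = {
     *         size: 512,
     *         strategy({sequence, size}) {
     *             // keep the first 64 tokens (e.g. a system preamble) and erase after them
     *             return [{start: 64, end: 64 + size}];
     *         }
     *     };
     */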
    /** @internal */
    _ensureNotDisposed() {
        if (this._disposed)
            throw new DisposedError();
    }
    /**
     * We need this to make it impossible to manually create instances of this class outside the code of this library
     * @internal
     */
    static _create({ sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(context.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor }) {
        return new LlamaContextSequence({
            sequenceId,
            context,
            tokenMeter,
            contextShift: {
                size: contextShiftSize,
                strategy: contextShiftStrategy
            },
            tokenPredictor
        });
    }
}
function getTokenBiasesForAddon(tokenBias, currentModel) {
    if (tokenBias == null)
        return {
            tokenBiasKeys: undefined,
            tokenBiasValues: undefined
        };
    if (tokenBias instanceof Function)
        tokenBias = tokenBias();
    if (tokenBias._tokenizer !== currentModel.tokenizer)
        throw new Error("This TokenBias instance was created with a different model than the one used by this context. " +
            "Make sure you use the model instance of the context sequence for the TokenBias you use it with.");
    // flatten the biases map into parallel typed arrays for the native addon
    const tokenBiasKeys = [];
    const tokenBiasValues = [];
    for (const [token, bias] of tokenBias._biases) {
        tokenBiasKeys.push(token);
        tokenBiasValues.push(bias);
    }
    if (tokenBiasKeys.length === 0 || tokenBiasValues.length === 0) {
        return {
            tokenBiasKeys: undefined,
            tokenBiasValues: undefined
        };
    }
    return {
        tokenBiasKeys: Uint32Array.from(tokenBiasKeys),
        tokenBiasValues: Float32Array.from(tokenBiasValues)
    };
}
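/*
 * For example (hypothetical bias values): a TokenBias whose internal map holds
 * { 42 => 0.5, 1337 => -0.25 } is flattened into
 *     tokenBiasKeys = Uint32Array [42, 1337]
 *     tokenBiasValues = Float32Array [0.5, -0.25]
 * so the native layer can read the biases as parallel typed arrays.
 */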
function reviveTokenProbabilities(probabilities) {
    if (probabilities == null)
        return undefined;
    const res = new Map();
    // the probabilities arrive as a flat array of [token, probability] pairs
    for (let i = 1; i < probabilities.length; i += 2) {
        const token = probabilities[i - 1];
        const probability = probabilities[i];
        res.set(token, probability);
    }
    return res;
}
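/*
 * For example (hypothetical values):
 *     reviveTokenProbabilities([120, 0.61, 88, 0.27, 13, 0.12])
 * yields Map { 120 => 0.61, 88 => 0.27, 13 => 0.12 }.
 */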
function disposeContextIfReferenced(contextRef) {
    const context = contextRef.deref();
    if (context != null)
        void context.dispose();
}
function disposeContextSequenceIfReferenced(contextRef) {
    const context = contextRef.deref();
    if (context != null)
        context.dispose();
}
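/*
 * Note: these helpers take a WeakRef and dispose the target only if it hasn't been
 * garbage collected yet - presumably so they can be registered as cleanup callbacks
 * (e.g. with a FinalizationRegistry) without keeping the context or sequence alive.
 */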
export function getDefaultContextBatchSize({ contextSize, sequences }) {
    return Math.min(contextSize * sequences, 512);
}
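/*
 * For example: contextSize = 4096 with sequences = 2 yields min(4096 * 2, 512) = 512,
 * while a small context of 256 tokens with a single sequence yields min(256 * 1, 512) = 256.
 */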
export function getDefaultContextSequences() {
    return 1;
}
const defaultFallbackContextSize = 4096;
export function getDefaultModelContextSize({ trainContextSize }) {
    return trainContextSize ?? defaultFallbackContextSize;
}
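/*
 * For example: a model whose GGUF metadata records a train context size of 8192 gets 8192,
 * while a model with no recorded train context size falls back to the default of 4096.
 */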
//# sourceMappingURL=LlamaContext.js.map