import path from "path"; import { acquireLock, AsyncDisposeAggregator, DisposeAggregator, DisposedError, EventRelay, withLock } from "lifecycle-utils"; import { removeNullFields } from "../../utils/removeNullFields.js"; import { compareTokens } from "../../utils/compareTokens.js"; import { DisposeGuard } from "../../utils/DisposeGuard.js"; import { TokenMeter } from "../TokenMeter.js"; import { UnsupportedError } from "../../utils/UnsupportedError.js"; import { pushAll } from "../../utils/pushAll.js"; import { safeEventCallback } from "../../utils/safeEventCallback.js"; import { GgufArchitectureType } from "../../gguf/types/GgufMetadataTypes.js"; import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js"; import { LlamaSampler } from "./LlamaSampler.js"; import { padSafeContextSize } from "./utils/padSafeContextSize.js"; const defaultLoraScale = 1; const shrinkRetriesMinContextSize = 4096; const defaultMaxPunishTokens = 64; const defaultFailedCreationRemedy = { retries: 16, autoContextSizeShrink: 0.16 }; const defaultEvaluationPriority = 5; const decodeSyncWorkaround = { vulkanLock: {} }; export class LlamaContext { /** @internal */ _llama; /** @internal */ _ctx; /** @internal */ _onReclaimUnusedSequenceId = new EventRelay(); /** @internal */ _backendContextDisposeGuard; /** @internal */ _model; /** @internal */ _contextSize; /** @internal */ _batchSize; /** @internal */ _flashAttention; /** @internal */ _idealThreads; /** @internal */ _minThreads; /** @internal */ _performanceTracking; /** @internal */ _totalSequences; /** @internal */ _unusedSequenceIds = []; /** @internal */ _batchingOptions; /** @internal */ _swaFullCache = false; /** @internal */ _queuedDecodeSequenceIds = new Set(); /** @internal */ _queuedDecodes = []; /** @internal */ _disposeAggregator = new AsyncDisposeAggregator(); /** @internal */ _modelPreventDisposalHandle; /** @internal */ _loraAdapters = new Set(); /** @internal */ _nextGeneratedSequenceId = 0; /** @internal */ _dispatchDecodeScheduled = false; /** @internal */ _batchDispatchPending = false; /** @internal */ _threadSplitterConsumer; /** @internal */ _freeReservedThreadsTimeout; /** @internal */ _currentDispatchBatchHandle = {}; /** @internal */ _allocatedContextSize; /** @internal */ _disposed = false; onDispose = new EventRelay(); constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, swaFullCache = _model.defaultContextSwaFullCache, performanceTracking = false, _embeddings, _ranking }) { if (_model.disposed) throw new DisposedError(); const kvUnified = false; this._llama = _model._llama; this._model = _model; this._backendContextDisposeGuard = new DisposeGuard([this._model._backendModelDisposeGuard]); this._modelPreventDisposalHandle = this._model._backendModelDisposeGuard.createPreventDisposalHandle(); this._totalSequences = Math.max(1, Math.floor(sequences)); this._contextSize = kvUnified ? Math.floor(padSafeContextSize(Math.max(2, contextSize) * this._totalSequences, "up") / this._totalSequences) : padSafeContextSize(Math.max(2, contextSize), "up"); this._batchSize = Math.max(batchSize, this._totalSequences); this._flashAttention = flashAttention; this._idealThreads = typeof threads === "number" ? 
this._llama._threadsSplitter.normalizeThreadsValue(threads) : this._llama._threadsSplitter.normalizeThreadsValue(threads?.ideal ?? (this._llama.maxThreads === 0 ? this._llama.cpuMathCores : this._llama.maxThreads)); this._minThreads = Math.max(1, typeof threads === "number" ? 1 : this._llama._threadsSplitter.normalizeThreadsValue(threads?.min ?? 1)); this._performanceTracking = !!performanceTracking; this._swaFullCache = !!swaFullCache; this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({ contextSize: padSafeContextSize(this._contextSize * this._totalSequences, "up"), // each sequence needs its own of cells batchSize: this._batchSize + ((!this._swaFullCache && this.model.fileInsights.swaSize != null && this.model.fileInsights.swaSize > 0) ? 1 // +1 to handle edge cases with SWA KV cache : 0), sequences: this._totalSequences, flashAttention: this._flashAttention, threads: this._idealThreads, embeddings: _embeddings, ranking: _ranking, performanceTracking: this._performanceTracking, swaFullCache: this._swaFullCache })); this._batchingOptions = { dispatchSchedule: batchingDispatchSchedule, itemPrioritizationStrategy: batchingItemsPrioritizationStrategy }; this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this); this._freeReservedThreads = this._freeReservedThreads.bind(this); this._disposeAggregator.add(() => { this._disposed = true; }); this._disposeAggregator.add(this._onReclaimUnusedSequenceId); this._disposeAggregator.add(this.onDispose.dispatchEvent); this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextIfReferenced.bind(null, new WeakRef(this)))); this._disposeAggregator.add(async () => { await this._backendContextDisposeGuard.acquireDisposeLock(); await this._ctx.dispose(); this._modelPreventDisposalHandle.dispose(); }); } async dispose() { if (this._disposed) return; this._disposed = true; await this._disposeAggregator.dispose(); } /** @hidden */ [Symbol.asyncDispose]() { return this.dispose(); } get disposed() { return this._disposed; } get model() { return this._model; } get contextSize() { return this._contextSize; } get batchSize() { return this._batchSize; } get flashAttention() { return this._flashAttention; } /** * The actual size of the state in the memory in bytes. * This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context. */ get stateSize() { this._ensureNotDisposed(); return this._ctx.getStateSize(); } /** The number of threads currently used to evaluate tokens */ get currentThreads() { this._ensureNotDisposed(); return this._ctx.getThreads(); } /** * The number of threads that are preferred to be used to evaluate tokens. * * The actual number of threads used may be lower when other evaluations are running in parallel. */ get idealThreads() { return this._idealThreads; } getAllocatedContextSize() { this._ensureNotDisposed(); if (this._allocatedContextSize == null) this._allocatedContextSize = this._ctx.getContextSize(); return this._allocatedContextSize; } get totalSequences() { return this._totalSequences; } get sequencesLeft() { return this._totalSequences - this._nextGeneratedSequenceId + this._unusedSequenceIds.length; } /** * Before calling this method, make sure to call `sequencesLeft` to check if there are any sequences left. * When there are no sequences left, this method will throw an error. 
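 *
 * A minimal usage sketch (assuming `context` was already created, e.g. via `model.createContext(...)`):
 * ```js
 * if (context.sequencesLeft === 0)
 *     throw new Error("No sequences left on this context");
 *
 * const sequence = context.getSequence();
 * try {
 *     // evaluate tokens on the sequence here
 * } finally {
 *     sequence.dispose(); // returns the sequence slot back to the context
 * }
 * ```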
*/ getSequence(options = {}) { const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, _tokenMeter } = options; this._ensureNotDisposed(); const nextSequenceId = this._popSequenceId(); if (nextSequenceId == null) throw new Error("No sequences left"); return LlamaContextSequence._create({ sequenceId: nextSequenceId, context: this, tokenMeter: _tokenMeter, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy }, tokenPredictor }); } dispatchPendingBatch() { this._currentDispatchBatchHandle = {}; this._dispatchDecodeScheduled = false; if (this._batchDispatchPending) return; this._batchDispatchPending = true; void withLock([this, "context"], async () => { this._currentDispatchBatchHandle = {}; this._dispatchDecodeScheduled = false; this._batchDispatchPending = false; let shouldHaveAnotherLoop = this._queuedDecodes.length > 0; const queuedDecodeToMappedLogits = new Map(); const resolvePrioritizationStrategy = () => { try { this._ensureNotDisposed(); return resolveBatchItemsPrioritizationStrategy(this._batchingOptions.itemPrioritizationStrategy); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); } return null; }; const getOrderedQueuedDecodes = (prioritizationStrategy) => { const batchItemToQueuedDecodeMap = new Map(); const batchItemsList = []; for (const queuedDecode of this._queuedDecodes) { const batchItem = { tokens: queuedDecode.tokens, logits: queuedDecode.logits, evaluationPriority: queuedDecode.evaluationPriority }; batchItemToQueuedDecodeMap.set(batchItem, queuedDecode); batchItemsList.push(batchItem); } let prioritizedItems; try { prioritizedItems = prioritizationStrategy({ items: batchItemsList, size: this._batchSize }); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); return null; } return prioritizedItems.map((prioritizedItem) => { const queuedDecode = batchItemToQueuedDecodeMap.get(prioritizedItem.item); if (queuedDecode == null) throw new Error("Received invalid batch item. Make sure you keep the original object reference " + "of the batch item on `item` on `PrioritizedBatchItem` in your custom prioritization strategy"); return { queuedDecode, processAmount: prioritizedItem.processAmount }; }); }; const fitQueuedDecodesToABatch = (queuedDecodes, batchSize) => { const currentBatchItems = []; let currentBatchSize = 0; let batchTokenSlotsLeft = batchSize; for (const { queuedDecode, processAmount } of queuedDecodes) { const resolvedProcessAmount = Math.min(processAmount <= 0 ? 1 : processAmount, queuedDecode.tokens.length, batchTokenSlotsLeft); if (resolvedProcessAmount <= 0) { if (batchTokenSlotsLeft === 0) break; continue; } batchTokenSlotsLeft -= resolvedProcessAmount; currentBatchSize += resolvedProcessAmount; currentBatchItems.push({ queuedDecode, processAmount: resolvedProcessAmount }); } return { currentBatchItems, currentBatchSize }; }; const decodeTokenBatchItems = async (batchItems, currentBatchSize) => { const afterDecodeActions = []; const queuedDecodesToDelete = new Set(); const currentQueuedDecodeItems = new Set(); if (currentBatchSize !== 0) this._ctx.initBatch(currentBatchSize); for (const { queuedDecode, processAmount } of batchItems) { let batchLogitIndexes; const tokensToProcess = queuedDecode.tokens.slice(0, processAmount); const tokenIndexesWithLogitsToProcess = queuedDecode.logits.slice(0, processAmount) .map((logit, index) => (logit ? 
index : undefined)) .filter((index) => index != undefined); const numberOfOutputTokens = tokenIndexesWithLogitsToProcess.length; TokenMeter.useTokens(queuedDecode.tokenMeter, Math.max(0, tokensToProcess.length - numberOfOutputTokens), "input"); TokenMeter.useTokens(queuedDecode.tokenMeter, numberOfOutputTokens, "output"); try { batchLogitIndexes = this._ctx.addToBatch(queuedDecode.sequenceId, queuedDecode.firstTokenSequenceIndex, Uint32Array.from(tokensToProcess), Uint32Array.from(tokenIndexesWithLogitsToProcess)); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err); continue; } currentQueuedDecodeItems.add(queuedDecode); if (queuedDecode.tokens.length === processAmount) { queuedDecodesToDelete.add(queuedDecode); afterDecodeActions.push({ queuedDecode, batchLogitIndexes, batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess, firstTokenIndex: queuedDecode.firstTokenSequenceIndex, returnResults: true }); } else { if (batchLogitIndexes.length > 0) afterDecodeActions.push({ queuedDecode, batchLogitIndexes, batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess, firstTokenIndex: queuedDecode.firstTokenSequenceIndex }); queuedDecode.tokens = queuedDecode.tokens.slice(processAmount); queuedDecode.logits = queuedDecode.logits.slice(processAmount); queuedDecode.firstTokenSequenceIndex += processAmount; } } for (let i = 0; i < this._queuedDecodes.length; i++) { const queuedDecode = this._queuedDecodes[i]; if (queuedDecodesToDelete.has(queuedDecode)) { this._queuedDecodes.splice(i, 1); this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId); i--; } } if (currentBatchSize !== 0) { const allocationResult = this._threadSplitterConsumer?.getAllocationToConsume(); const [threadsToUse, consumerHandle] = allocationResult instanceof Promise ? await allocationResult ?? [] : allocationResult ?? []; try { if (threadsToUse != null) this._ctx.setThreads(threadsToUse); await this._ctx.decodeBatch(); consumerHandle?.dispose(); } catch (err) { consumerHandle?.dispose(); this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err); return; } } function finishAfterDecodeAction(action, mappedLogitValues) { if (mappedLogitValues != null && mappedLogitValues.length > 0) { if (queuedDecodeToMappedLogits.has(action.queuedDecode)) pushAll(queuedDecodeToMappedLogits.get(action.queuedDecode), mappedLogitValues); else queuedDecodeToMappedLogits.set(action.queuedDecode, mappedLogitValues); } if (action.returnResults != null) { const [accept] = action.queuedDecode.response; const mappedLogits = queuedDecodeToMappedLogits.get(action.queuedDecode) ?? []; queuedDecodeToMappedLogits.delete(action.queuedDecode); accept(mappedLogits); } } const afterDecodeActionResults = afterDecodeActions.map((action) => { if (action.batchLogitIndexes.length === 0) { finishAfterDecodeAction(action); return undefined; } const mappedLogitValues = []; let promiseChain = undefined; const batchLogitIndexes = action.batchLogitIndexes; const batchLogitTokenIndexes = action.batchLogitTokenIndexes; for (let i = 0; i < batchLogitIndexes.length; i++) { const tokenIndex = batchLogitTokenIndexes[i]; const mappedValue = promiseChain != null ? 
promiseChain .then(() => action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex)) : action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex); if (mappedValue instanceof Promise) { promiseChain = mappedValue; mappedLogitValues.push(mappedValue .then((value) => [tokenIndex + action.firstTokenIndex, value])); } else mappedLogitValues.push([tokenIndex + action.firstTokenIndex, mappedValue]); } if (promiseChain != null) return Promise.all(mappedLogitValues) .then((resolvedMappedLogitValues) => finishAfterDecodeAction(action, resolvedMappedLogitValues)); finishAfterDecodeAction(action, mappedLogitValues); return undefined; }); await Promise.all(afterDecodeActionResults); }; const prioritizationStrategy = resolvePrioritizationStrategy(); if (prioritizationStrategy == null) return; // all queued items are rejected and dequeued when we get here this._reserveThreads(); try { while (shouldHaveAnotherLoop) { const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizationStrategy); if (orderedQueuedDecodes == null) return; // all queued items are rejected and dequeued when we get here const { currentBatchItems, currentBatchSize } = fitQueuedDecodesToABatch(orderedQueuedDecodes, this._batchSize); let preventDisposalHandle; try { preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle(); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); return; } let decodeLock; // this is a workaround to prevent Vulkan from crashing the process when decoding on multiple contexts in parallel if (this._llama.gpu === "vulkan") decodeLock = await acquireLock([decodeSyncWorkaround.vulkanLock, "decode"]); try { await decodeTokenBatchItems(currentBatchItems, currentBatchSize); shouldHaveAnotherLoop = this._queuedDecodes.length > 0; } finally { decodeLock?.dispose(); preventDisposalHandle.dispose(); } } } finally { this._scheduleToFreeReservedThreads(); } }); } /** * Print the timings of token evaluation since that last print for this context. * * Requires the `performanceTracking` option to be enabled. * * > **Note:** it prints on the `LlamaLogLevel.info` level, so if you set the level of your `Llama` instance higher than that, * it won't print anything. 
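 *
 * A minimal sketch (assuming the context was created with the `performanceTracking` option enabled, e.g. `model.createContext({performanceTracking: true})`):
 * ```js
 * // ...evaluate some tokens on a sequence of this context...
 * await context.printTimings();
 * ```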
*/ async printTimings() { this._ensureNotDisposed(); if (!this._performanceTracking) throw new UnsupportedError("Performance tracking is not enabled"); this._ctx.printTimings(); await new Promise((accept) => setTimeout(accept, 0)); // wait for the logs to finish printing } /** @internal */ async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter }, logitDataMapper) { return await new Promise((accept, reject) => { this._queuedDecodes.push({ sequenceId, tokens, logits, firstTokenSequenceIndex, evaluationPriority, tokenMeter, response: [accept, reject], logitDataMapper }); this._queuedDecodeSequenceIds.add(sequenceId); this._scheduleDecode(); }); } /** @internal */ _reclaimUnusedSequenceId(sequenceId) { if (this._disposed) return; void withLock([this, "context"], async () => { if (this._disposed) return; this._ctx.disposeSequence(sequenceId); this._unusedSequenceIds.push(sequenceId); this._onReclaimUnusedSequenceId.dispatchEvent(); }); } /** @internal */ _popSequenceId() { if (this._unusedSequenceIds.length > 0) return this._unusedSequenceIds.shift(); if (this._nextGeneratedSequenceId < this._totalSequences) { const sequenceId = this._nextGeneratedSequenceId; this._nextGeneratedSequenceId++; return sequenceId; } return null; } /** @internal */ _scheduleDecode() { if (this._dispatchDecodeScheduled || this._batchDispatchPending) return; this._dispatchDecodeScheduled = true; const currentPendingBatchHandle = this._currentDispatchBatchHandle; const dispatch = () => { if (this._currentDispatchBatchHandle !== currentPendingBatchHandle) return; this.dispatchPendingBatch(); }; const dispatchSchedule = this._batchingOptions.dispatchSchedule; if (this._queuedDecodeSequenceIds.size === this._totalSequences) dispatch(); if (dispatchSchedule === "nextCycle") { if (typeof setImmediate === "function") setImmediate(dispatch); else setTimeout(dispatch, 0); } else if (typeof dispatchSchedule === "function") dispatchSchedule(dispatch); else { if (typeof setImmediate === "function") setImmediate(dispatch); else setTimeout(dispatch, 0); } } /** @internal */ _dispatchErrorForQueuedDecodesAndDequeue(queuedDecodes, err) { for (const pendingDecode of queuedDecodes) { const [, reject] = pendingDecode.response; reject(err); } for (let i = 0; i < this._queuedDecodes.length; i++) { const item = this._queuedDecodes[i]; if (queuedDecodes.has(item)) { this._queuedDecodes.splice(i, 1); this._queuedDecodeSequenceIds.delete(item.sequenceId); i--; } } } /** @internal */ _ensureNotDisposed() { if (this._disposed) throw new DisposedError(); } /** @internal */ async _setLora({ filePath, scale }) { const lora = await this._model._getOrLoadLora(filePath); this._ctx.setLora(lora, scale ?? 
defaultLoraScale); if (!this._loraAdapters.has(lora)) { this._loraAdapters.add(lora); lora.usages++; } } /** @internal */ _reserveThreads() { clearTimeout(this._freeReservedThreadsTimeout); delete this._freeReservedThreadsTimeout; if (this._threadSplitterConsumer != null) return; this._threadSplitterConsumer = this._llama._threadsSplitter.createConsumer(this._idealThreads, this._minThreads); } /** @internal */ _freeReservedThreads() { clearTimeout(this._freeReservedThreadsTimeout); delete this._freeReservedThreadsTimeout; if (this._threadSplitterConsumer == null) return; this._threadSplitterConsumer.dispose(); delete this._threadSplitterConsumer; } /** @internal */ _scheduleToFreeReservedThreads() { if (this._threadSplitterConsumer == null) return; clearTimeout(this._freeReservedThreadsTimeout); this._freeReservedThreadsTimeout = setTimeout(this._freeReservedThreads, 0); } /** @internal */ static async _create(options, { _model }) { const sequences = options.sequences ?? getDefaultContextSequences(); const flashAttention = _model.flashAttentionSupported ? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention) : false; const swaFullCache = options.swaFullCache ?? _model.defaultContextSwaFullCache; const loraOptions = typeof options.lora === "string" ? { adapters: [{ filePath: options.lora }] } : options.lora; let failedCreationRetries = options.failedCreationRemedy === false ? 0 : Math.max(0, options.failedCreationRemedy?.retries ?? defaultFailedCreationRemedy.retries); const failedCreationAutoContextSizeShrink = options.failedCreationRemedy === false ? 0 : options.failedCreationRemedy?.autoContextSizeShrink ?? defaultFailedCreationRemedy.autoContextSizeShrink; let contextSize = await _model.fileInsights.configurationResolver.resolveContextContextSize(options.contextSize, { batchSize: options.batchSize, sequences: sequences, modelGpuLayers: _model.gpuLayers, modelTrainContextSize: _model.trainContextSize, flashAttention, swaFullCache, getVramState: () => _model._llama._vramOrchestrator.getMemoryState(), llamaGpu: _model._llama.gpu, ignoreMemorySafetyChecks: options.ignoreMemorySafetyChecks, isEmbeddingContext: options._embeddings }); const minContextSize = options.contextSize === "auto" ? shrinkRetriesMinContextSize : (typeof options.contextSize === "object" && typeof options.contextSize.min === "number") ? options.contextSize.min : typeof options.contextSize === "number" ? options.contextSize : shrinkRetriesMinContextSize; const { createSignal } = options; async function createContext(contextSize) { const batchSize = options.batchSize ?? getDefaultContextBatchSize({ contextSize, sequences }); const resourceRequirementsEstimation = _model.fileInsights.estimateContextResourceRequirements({ contextSize, sequences, isEmbeddingContext: options._embeddings, modelGpuLayers: _model.gpuLayers, batchSize, flashAttention, swaFullCache }); const context = new LlamaContext({ _model }, { ...options, contextSize, batchSize, sequences, flashAttention, swaFullCache }); const contextCreationVramReservation = options.ignoreMemorySafetyChecks ? null : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.gpuVram); const contextCreationRamReservation = options.ignoreMemorySafetyChecks ? 
null : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.cpuRam); try { if (createSignal?.aborted) throw createSignal.reason; const contextLoaded = await context._ctx.init(); if (createSignal?.aborted) { if (contextLoaded) await context._ctx.dispose(); throw createSignal.reason; } else if (!contextLoaded) throw new Error("Failed to create context"); contextCreationVramReservation?.dispose?.(); contextCreationRamReservation?.dispose?.(); if (loraOptions != null && loraOptions.adapters.length > 0) { let loadedAdapters = 0; for (const adapter of loraOptions.adapters) { try { await context._setLora({ filePath: adapter.filePath, scale: adapter.scale }); loadedAdapters++; try { loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length); } catch (err) { console.error(err); } } catch (err) { await context.dispose(); throw err; } if (createSignal?.aborted) { await context.dispose(); throw createSignal.reason; } } } else if (loraOptions?.onLoadProgress != null) { try { loraOptions.onLoadProgress(1); } catch (err) { console.error(err); } } return context; } finally { contextCreationVramReservation?.dispose?.(); contextCreationRamReservation?.dispose?.(); } } while (failedCreationRetries >= 0) { try { return await createContext(contextSize); } catch (err) { if (failedCreationRetries === 0 || (createSignal?.aborted && err === createSignal.reason)) throw err; failedCreationRetries--; let newContextSize = typeof failedCreationAutoContextSizeShrink === "number" ? Math.floor(contextSize * (1 - failedCreationAutoContextSizeShrink)) : Math.floor(failedCreationAutoContextSizeShrink(contextSize)); if (!Number.isFinite(newContextSize)) throw err; if (newContextSize < minContextSize) newContextSize = minContextSize; if (newContextSize >= contextSize) throw err; contextSize = newContextSize; } } throw new Error("Failed to create context"); } } export class LlamaContextSequence { /** @internal */ _sequenceId; /** @internal */ _gcRegistry; /** @internal */ _context; /** @internal */ _contextShift; /** @internal */ _tokenPredictor; /** @internal */ _tokenMeter; /** @internal */ _disposeAggregator = new DisposeAggregator(); /** @internal */ _lock = {}; /** @internal */ _resetTokenPredictor = false; /** @internal */ _tokenPredictorOwner = {}; /** @internal */ _contextTokens = []; /** @internal */ _nextTokenIndex = 0; /** @internal */ _loadedTokenPredictions = []; /** @internal */ _usedTokenPredictions = 0; /** @internal */ _unusedTokenPredictions = 0; /** @internal */ _validatedTokenPredictions = 0; /** @internal */ _refutedTokenPredictions = 0; /** @internal */ _disposed = false; onDispose = new EventRelay(); constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor }) { this._sequenceId = sequenceId; this._context = context; this._tokenMeter = tokenMeter ?? 
new TokenMeter(); this._contextShift = contextShift; this._tokenPredictor = tokenPredictor; this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId); this._gcRegistry.register(this, sequenceId); this._disposeAggregator.add(() => this._gcRegistry.unregister(this)); this._disposeAggregator.add(this.onDispose.dispatchEvent); this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextSequenceIfReferenced.bind(null, new WeakRef(this)))); this._disposeAggregator.add(() => { this._context._reclaimUnusedSequenceId(this._sequenceId); }); if (this._tokenPredictor != null) this._disposeAggregator.add(this._tokenPredictor); } dispose() { if (this._disposed) return; this._disposeAggregator.dispose(); this._contextTokens.length = 0; this._disposed = true; } /** @hidden */ [Symbol.dispose]() { return this.dispose(); } get disposed() { return this._disposed; } get context() { return this._context; } get model() { return this._context.model; } /** The maximum number of tokens that the sequence state can hold */ get contextSize() { return this._context.contextSize; } /** The index where the next evaluated token will be placed in the context */ get nextTokenIndex() { return this._nextTokenIndex - this._loadedTokenPredictions.length; } /** The current context state tokens */ get contextTokens() { if (this._loadedTokenPredictions.length === 0) return this._contextTokens.slice(); return this._contextTokens.slice(0, -this._loadedTokenPredictions.length); } get tokenMeter() { return this._tokenMeter; } /** * The token predictor used when creating this sequence. */ get tokenPredictor() { return this._tokenPredictor; } /** * Get the index of the first token in the KV cache. * * If you remove any tokens from the state that come before this index, * no cached prefix tokens evaluation state will be used for the next evaluation. * * For example, if `stateCellsStartIndex` is `10` and you remove the range `{start: 11, end: 16}` * then the cached state for range `0-10` will be used in the next evaluation, * but if you remove the range `{start: 10, end: 16}` (or `{start: 9, end: 16}`) then the cached state will not be used at all * and will be re-evaluated in the next evaluation. * * This index can be greater than `0` only when SWA (Sliding Window Attention) is used (only on supported models). * * When SWA is used, this index will usually be `Math.max(-1, .nextTokenIndex - .model.fileInsights.swaSize)` or larger. * * When the KV cache is empty, this index will be `-1`. * * You can disable SWA by setting the `swaFullCache` option to `true` when creating a context. */ get stateCellsStartIndex() { this._ensureNotDisposed(); return this._context._ctx.getSequenceKvCacheMinPosition(this._sequenceId); } /** * Statistics of token predictions using the sequence's `tokenPredictor`. * * The statistics change only when token prediction is used in this sequence. * * `validated` + `refuted` = total number of evaluated predictions. * * Prefer using `validated` and `refuted` to evaluate the effectiveness of token prediction. 
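 *
 * For example, a sketch of estimating the predictor's hit rate (assuming this sequence was created with a `tokenPredictor`):
 * ```js
 * const {validated, refuted} = sequence.tokenPredictions;
 * const evaluated = validated + refuted;
 * const hitRate = evaluated === 0 ? 0 : validated / evaluated;
 * console.log(`Prediction hit rate: ${(hitRate * 100).toFixed(1)}%`);
 * ```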
*/ get tokenPredictions() { return { used: this._usedTokenPredictions, unused: this._unusedTokenPredictions, validated: this._validatedTokenPredictions, refuted: this._refutedTokenPredictions }; } get isLoadedToMemory() { return !this._disposed; } compareContextTokens(tokens) { for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length; i++) { if (compareTokens(this._contextTokens[i], tokens[i])) continue; return { firstDifferentIndex: i }; } return { firstDifferentIndex: this._contextTokens.length - this._loadedTokenPredictions.length }; } /** * Erase parts of the context state to align it with the given tokens. * * If the given tokens do not align with the current context state, the context state will be erased to align with the given tokens. * * To find the first different token index between the context state and the given tokens, access the `nextTokenIndex` property. * * If `allowShift` is `true` (the default), shifting tokens may happen to align the context state with the given tokens, * which incurs token evaluation of the shifted tokens. */ async adaptStateToTokens(tokens, allowShift = true) { const modelSupportsShifting = !this.model.fileInsights.isRecurrent && this.model.fileInfo.metadata?.general?.architecture !== GgufArchitectureType.deepseek2; if (!modelSupportsShifting || !allowShift) { const { firstDifferentIndex } = this.compareContextTokens(tokens); if (firstDifferentIndex < this.nextTokenIndex) await this._eraseContextTokenRanges([{ start: firstDifferentIndex, end: this._nextTokenIndex }]); return; } const eraseRanges = []; let tokensIndex = 0; let differentTokenIndex = undefined; for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length && tokensIndex < tokens.length; i++) { if (compareTokens(this._contextTokens[i], tokens[tokensIndex])) { if (differentTokenIndex != null) { eraseRanges.push({ start: differentTokenIndex, end: i }); differentTokenIndex = undefined; } tokensIndex++; continue; } if (differentTokenIndex == null) differentTokenIndex = i; } if (differentTokenIndex != null) eraseRanges.push({ start: differentTokenIndex, end: this._nextTokenIndex }); if (eraseRanges.length > 0) await this._eraseContextTokenRanges(eraseRanges); } /** * Clear the history of the sequence. */ async clearHistory() { this._ensureNotDisposed(); await this._eraseContextTokenRanges([{ start: 0, end: this._nextTokenIndex }]); } /** * Erase context tokens in the provided ranges to free up space for new tokens to be generated. * The start of each range is inclusive, and the end of each range is exclusive. * For example, the range `{start: 0, end: 1}` will remove the token at the `0` index only. 
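 *
 * A minimal sketch (hypothetical indexes; erases the first 4 tokens and tokens 10-19 of the state):
 * ```js
 * await sequence.eraseContextTokenRanges([
 *     {start: 0, end: 4},
 *     {start: 10, end: 20}
 * ]);
 * ```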
*/ eraseContextTokenRanges(ranges) { return this._eraseContextTokenRanges(ranges); } /** @internal */ async _eraseContextTokenRanges(ranges, { canResetTokenPredictor = true, canRemovePredictionTokens = true, skipLock = false } = {}) { this._ensureNotDisposed(); let awaitPromise; await withLock([this._context, "context"], async () => { this._ensureNotDisposed(); if (ranges.length === 0) return; // if the deletion fails, we'll have to dispose the sequence and fill it up again let deletionSuccessful = true; const resolvedRanges = ranges .map(({ start, end }) => { if (start === end) return null; if (start > end) [start, end] = [end, start]; if (end > this._nextTokenIndex) end = this._nextTokenIndex; if (start >= this._nextTokenIndex) return null; return { start, end }; }) .filter((range) => range != null) .sort((a, b) => a.start - b.start) .reduce((ranges, range) => { if (ranges.length === 0) return [range]; const lastRange = ranges[ranges.length - 1]; if (lastRange.end >= range.start) { lastRange.end = Math.max(lastRange.end, range.end); return ranges; } ranges.push(range); return ranges; }, []); const minKvCachePosition = (this._contextTokens.length === 0 && this._loadedTokenPredictions.length === 0) ? 0 : Math.max(0, this._context._ctx.getSequenceKvCacheMinPosition(this._sequenceId)); if (resolvedRanges[0] != null && resolvedRanges[0].start <= minKvCachePosition) // we have to drop the cache and reevaluate the sequence due to missing KV cache deletionSuccessful = false; const tokenPredictionsToRemove = (resolvedRanges.length > 0 && canRemovePredictionTokens) ? this._loadedTokenPredictions.length : 0; if (tokenPredictionsToRemove > 0) { const startDeleteIndex = this._nextTokenIndex - this._loadedTokenPredictions.length; const lastDeleteRange = resolvedRanges[resolvedRanges.length - 1]; if (lastDeleteRange.end >= startDeleteIndex) lastDeleteRange.end = this._nextTokenIndex; else resolvedRanges.push({ start: startDeleteIndex, end: this._nextTokenIndex }); if (canResetTokenPredictor) await this._abortTokenPredictor(true); } let removedTokens = 0; let lastDeleteRangeEndPos = null; for (const range of resolvedRanges) { this._contextTokens.splice(range.start - removedTokens, range.end - range.start); if (deletionSuccessful) deletionSuccessful &&= this._context._ctx.removeTokenCellsFromSequence(this._sequenceId, range.start, range.end); if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start) { this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, range.start, -removedTokens); const shiftedTokens = range.start - lastDeleteRangeEndPos; this._tokenMeter.useTokens(shiftedTokens, "input"); } removedTokens += range.end - range.start; lastDeleteRangeEndPos = range.end; } if (tokenPredictionsToRemove > 0) this._loadedTokenPredictions.splice(0, tokenPredictionsToRemove); if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== this._nextTokenIndex) { this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, this._nextTokenIndex, -removedTokens); const shiftedTokens = this._nextTokenIndex - lastDeleteRangeEndPos; this._tokenMeter.useTokens(shiftedTokens, "input"); } this._nextTokenIndex -= removedTokens; if (canResetTokenPredictor && removedTokens > 0) await this._abortTokenPredictor(true); if (deletionSuccessful) return; const newSequenceTokens = this._contextTokens.slice(); this._nextTokenIndex = 0; 
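// removing token cells from the KV cache failed for at least one range,
// so drop the entire sequence state and re-evaluate the kept tokens from scratch below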
this._context._ctx.disposeSequence(this._sequenceId); // wait for the evaluation outside the "context" lock to avoid deadlocks awaitPromise = this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock }); }); if (awaitPromise != null) await awaitPromise; } /** * Evaluate the provided tokens into the context sequence, and continue generating new tokens on iterator iterations. * * This method uses the token predictor (when provided) to generate new tokens faster. */ async *evaluate(tokens, options = {}) { const iterator = this.evaluateWithMetadata(tokens, {}, options); let iterateInput = undefined; try { while (true) { const { value, done } = await iterator.next(iterateInput); if (done) return; iterateInput = yield value.token; } } finally { await iterator.return(); } } /** * Like {@link evaluate `.evaluate(...)`}, but with additional metadata for each generated token. * * Configure the additional metadata options to choose which metadata to include. */ evaluateWithMetadata(tokens, metadata, options = {}) { const { temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, yieldEogToken = false, _noSampling = false } = options; if (this._tokenPredictor != null && !_noSampling && tokens.length > 0) return this._speculativeEvaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, yieldEogToken, tokenPredictor: this._tokenPredictor }); return this._evaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, yieldEogToken, _noSampling }); } /** * Evaluate the provided tokens into the context sequence without generating new tokens. */ async evaluateWithoutGeneratingNewTokens(tokens, options = {}) { const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, _skipLock = false } = options; const iterator = this._evaluate(tokens, {}, { generateNewTokens: false, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, _skipLock }); const predictorAlignmentPromise = this.tokenPredictor == null ? undefined : this._tokenPredictor?.reset({ stateTokens: [...this._contextTokens, ...tokens], evaluateOptions: { evaluationPriority, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy } }, targetSequence: this }); if (predictorAlignmentPromise != null) { this._tokenPredictorOwner = {}; this._resetTokenPredictor = false; } // eslint-disable-next-line @typescript-eslint/no-unused-vars for await (const token of iterator) { // Array.from doesn't work with async generators, so we have to iterate over the generator } await iterator.return(); if (predictorAlignmentPromise != null) await predictorAlignmentPromise; } /** * Evaluate the provided tokens into the context sequence with custom options for each token. * * This method allows for more precise control of the generation process. 
* * A next token will be generated for a given token only if any of the `generateNext` options for it are used. * * To generate more tokens after this method finishes, * use it again with token(s) you selected to add to the context from the previous evaluation. * * This method doesn't use the token predictor (when provided) since it cannot predict which tokens are actually needed. * Use the `evaluate` method when you need to use token prediction. * @returns An array where for each token in the input array, there can be an output item at the same index in the output array. * For indexes that have no output, there won't be any value at the corresponding index in the output array. * * It's recommended to iterate from `0` up to the length of the input array to check the results in the output array. */ async controlledEvaluate(input, options) { const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {} } = options ?? {}; const contextShiftOptions = { size: contextShiftSize, strategy: contextShiftStrategy }; this._ensureNotDisposed(); if (input.length === 0) return []; await this._abortTokenPredictor(); const sampler = new LlamaSampler(this.model); const onTokenResult = safeEventCallback(options?.onTokenResult); const logitsArray = []; const resolvedTokens = input.map((item, index) => { if (item instanceof Array) { const [token, options] = item; const generateNext = options?.generateNext ?? {}; if (generateNext.probabilities === true || generateNext.confidence === true || generateNext.token === true) logitsArray[index] = true; return token; } return item; }); const evaluatorLock = await acquireLock([this._lock, "evaluate"]); try { return await this._decodeTokens(resolvedTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, async (batchLogitIndex, tokenIndex) => { const inputToken = input[tokenIndex]; const inputOptions = inputToken instanceof Array ? (inputToken[1] ?? {}) : {}; const generateNext = inputOptions.generateNext; if (generateNext == null || ((generateNext.probabilities == null || !generateNext.probabilities) && (generateNext.token == null || !generateNext.token) && (generateNext.confidence == null || !generateNext.confidence))) return undefined; const sampleOptions = generateNext.options ?? {}; const samplerConfig = this._resolveSamplerConfig({ temperature: sampleOptions.temperature, minP: sampleOptions.minP, topK: sampleOptions.topK, topP: sampleOptions.topP, seed: sampleOptions.seed, repeatPenalty: sampleOptions.repeatPenalty, tokenBias: sampleOptions.tokenBias }); return await withLock([sampler, "sample"], async () => { if (sampler.disposed) return undefined; sampler.applyConfig(samplerConfig); const [token, probabilities, confidence] = await this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, !!generateNext.probabilities, !!generateNext.confidence); const output = { next: {} }; if (generateNext.token) output.next.token = token === -1 ? null : (token ?? null); if (confidence != null) output.next.confidence = confidence; if (probabilities != null) output.next.probabilities = reviveTokenProbabilities(probabilities); onTokenResult?.(tokenIndex, output); return output; }); }); } finally { evaluatorLock.dispose(); void withLock([sampler, "sample"], sampler.asyncDispose); } } /* eslint-disable @stylistic/max-len */ /** * Save the current context sequence evaluation state to a file. 
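 *
 * A minimal sketch of saving and later restoring the state (hypothetical file path; `loadStateFromFile` is documented below):
 * ```js
 * const {fileSize} = await sequence.saveStateToFile("state.bin");
 * console.log(`Saved ${fileSize} bytes`);
 *
 * await sequence.clearHistory();
 * await sequence.loadStateFromFile("state.bin", {acceptRisk: true});
 * ```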
* @see [Saving and restoring a context sequence evaluation state](https://node-llama-cpp.withcat.ai/guide/chat-session#save-and-restore-with-context-sequence-state) */ async saveStateToFile(filePath) { /* eslint-enable @stylistic/max-len */ this._ensureNotDisposed(); const resolvedPath = path.resolve(process.cwd(), filePath); const evaluatorLock = await acquireLock([this._lock, "evaluate"]); const contextLock = await acquireLock([this._context, "context"]); try { this._ensureNotDisposed(); const fileSize = await this._context._ctx.saveSequenceStateToFile(resolvedPath, this._sequenceId, Uint32Array.from(this.contextTokens)); return { fileSize }; } finally { contextLock.dispose(); evaluatorLock.dispose(); } } /* eslint-disable @stylistic/max-len */ /** * Load a context sequence evaluation state from a file. * * Trying to load a state file with a longer context size than the current sequence's context size will fail and throw an error. * * You must ensure that the file was created from the exact same model, otherwise, using this function may crash the process. * @see [Saving and restoring a context sequence evaluation state](https://node-llama-cpp.withcat.ai/guide/chat-session#save-and-restore-with-context-sequence-state) */ async loadStateFromFile(filePath, acceptRisk) { /* eslint-enable @stylistic/max-len */ if (!acceptRisk.acceptRisk) throw new Error("The `acceptRisk` option must be set to `true` to use this feature"); this._ensureNotDisposed(); const resolvedPath = path.resolve(process.cwd(), filePath); const evaluatorLock = await acquireLock([this._lock, "evaluate"]); const contextLock = await acquireLock([this._context, "context"]); try { this._ensureNotDisposed(); this._tokenPredictorOwner = {}; await this._abortTokenPredictor(true); this._ensureNotDisposed(); this._loadedTokenPredictions.length = 0; this._nextTokenIndex = 0; this._contextTokens = []; const tokens = Array.from(await this._context._ctx.loadSequenceStateFromFile(resolvedPath, this._sequenceId, this.contextSize)); if (tokens.length > this.contextSize) { this._context._ctx.disposeSequence(this._sequenceId); throw new Error("The given state file is too large for the current context size"); } this._contextTokens = tokens; this._nextTokenIndex = tokens.length; this._loadedTokenPredictions.length = 0; } finally { contextLock.dispose(); evaluatorLock.dispose(); } } /** @internal */ async *_evaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, generateNewTokens = true, contextShiftOptions, yieldEogToken = false, _noSampling = false, _skipLock = false }) { this._ensureNotDisposed(); let evalTokens = tokens; if (evalTokens.length === 0) return; await this._abortTokenPredictor(false, true); const sampleProbabilities = metadata.probabilities === true; const sampleConfidence = metadata.confidence === true; const sampler = new LlamaSampler(this.model); try { while (true) { this._ensureNotDisposed(); const evaluatorLock = _skipLock ? undefined : await acquireLock([this._lock, "evaluate"]); let nextToken; const yieldRes = {}; try { const logitsArray = []; if (generateNewTokens) logitsArray[evalTokens.length - 1] = true; // Evaluate to get the next token. 
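// only the last evaluated token requests logits, so at most one new token is sampled per decode;
// the logit mapper below runs when the batch is dispatched and samples the next token using the current sampler config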
const decodeResult = await this._decodeTokens(evalTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, (batchLogitIndex) => { if (_noSampling) return null; const samplerConfig = this._resolveSamplerConfig({ temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias }); return withLock([sampler, "sample"], async () => { if (sampler.disposed) return null; sampler.applyConfig(samplerConfig); if (sampleProbabilities || sampleConfidence) return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, sampleProbabilities, sampleConfidence); else return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler); }); }); const lastDecodeResult = decodeResult[evalTokens.length - 1]; if (lastDecodeResult instanceof Array) { const [token, probabilities, confidence] = lastDecodeResult; nextToken = token; if (probabilities != null) yieldRes.probabilities = reviveTokenProbabilities(probabilities); if (confidence != null) yieldRes.confidence = confidence; } else nextToken = lastDecodeResult; if (nextToken === -1) throw new Error("Failed to sample next token"); if (nextToken == null) return; // the model finished generating text if (!yieldEogToken && this._context.model.isEogToken(nextToken)) break; } finally { evaluatorLock?.dispose(); } yieldRes.token = nextToken; const replacementToken = yield yieldRes; // set the tokens for the next evaluation if (replacementToken instanceof Array) evalTokens = replacementToken.slice(); else if (replacementToken != null) evalTokens = [replacementToken]; else evalTokens = [nextToken]; } } finally { void withLock([sampler, "sample"], sampler.asyncDispose); } } /** @internal */ async *_speculativeEvaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShiftOptions, yieldEogToken = false, tokenPredictor }) { this._ensureNotDisposed(); let evalTokens = tokens.slice(); if (evalTokens.length === 0) return; const tokenPredictorOwner = {}; this._tokenPredictorOwner = tokenPredictorOwner; await this._abortTokenPredictor(); const sampleProbabilities = metadata.probabilities === true; const sampleConfidence = metadata.confidence === true; let logitsArray = []; let logitsStartIndex = evalTokens.length - 1; const validatedTokens = []; logitsArray[logitsStartIndex] = true; const sampler = new LlamaSampler(this.model); try { while (true) { this._ensureNotDisposed(); const evaluatorLock = await acquireLock([this._lock, "evaluate"]); let nextToken; const yieldRes = {}; try { if (this._tokenPredictorOwner === tokenPredictorOwner && this._loadedTokenPredictions.length > 0 && evalTokens.length === 1 && evalTokens[0] === this._loadedTokenPredictions[0]?.[0]) { const [token, probabilities, confidence] = this._loadedTokenPredictions.shift()[1]; nextToken = token; yieldRes.token = nextToken; if (probabilities != null) yieldRes.probabilities = reviveTokenProbabilities(probabilities); if (confidence != null) yieldRes.confidence = confidence; const resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function ? 
grammarEvaluationState() : grammarEvaluationState; if (resolvedGrammarEvaluationState != null) LlamaSampler._acceptTokenOnGrammarEvaluationState(this._context._llama, resolvedGrammarEvaluationState, nextToken); this._unusedTokenPredictions--; this._usedTokenPredictions++; } else if (this._tokenPredictorOwner === tokenPredictorOwner && this._loadedTokenPredictions.length > 0) { const deleteStartIndex = Math.max(0, this._nextTokenIndex - this._loadedTokenPredictions.length); await this._eraseContextTokenRanges([{ start: deleteStartIndex, end: this._nextTokenIndex }], { canResetTokenPredictor: true, canRemovePredictionTokens: true, skipLock: true }); this._loadedTokenPredictions.length = 0; } if (this._resetTokenPredictor) { await tokenPredictor.reset({ stateTokens: [...this._contextTokens, ...evalTokens], evaluateOptions: { temperature, minP, topK, topP, seed, grammarEvaluationState: grammarEvaluationState instanceof Function ? grammarEvaluationState()?.clone() : grammarEvaluationState?.clone(), repeatPenalty, tokenBias, evaluationPriority, contextShift: contextShiftOptions, yieldEogToken: true }, targetSequence: this }); this._resetTokenPredictor = false; this._tokenPredictorOwner = tokenPredictorOwner; } if (nextToken == null) { if (this._tokenPredictorOwner === tokenPredictorOwner && // prevent incurring context shifts due to token prediction validations this._nextTokenIndex + evalTokens.length < this._context.contextSize) { const testGrammarClone = grammarEvaluationState instanceof Function ? grammarEvaluationState()?.clone() : grammarEvaluationState?.clone(); for (const token of await tokenPredictor.predictTokens()) { if (testGrammarClone != null) { const canAddToken = LlamaSampler._canBeNextTokenForGrammarEvaluationState(this.model._llama, testGrammarClone, token); if (!canAddToken) break; } evalTokens.push(token); logitsArray[evalTokens.length - 1] = true; // prevent incurring context shifts due to token prediction validations if (this._nextTokenIndex + evalTokens.length >= this._context.contextSize) break; } } let resolvedGrammarEvaluationState = undefined; // Evaluate to get the next token. const decodeResult = await this._decodeTokens(evalTokens, logitsArray, evaluationPriority, this._tokenMeter, contextShiftOptions, (batchLogitIndex, tokenIndex) => { if (tokenIndex === logitsStartIndex) resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function ? grammarEvaluationState() : grammarEvaluationState; else if (tokenIndex === logitsStartIndex + 1) resolvedGrammarEvaluationState = resolvedGrammarEvaluationState?.clone(); const samplerConfig = this._resolveSamplerConfig({ temperature, minP, topK, topP, seed, grammarEvaluationState: resolvedGrammarEvaluationState, repeatPenalty, tokenBias }); return withLock([sampler, "sample"], async () => { if (sampler.disposed) return null; sampler.applyConfig(samplerConfig); if (sampleProbabilities || sampleConfidence) return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler, sampleProbabilities, sampleConfidence); else return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler); }); }); for (let i = logitsStartIndex; i < evalTokens.length; i++) { const item = decodeResult[i]; const [resultToken, probabilities, confidence] = item instanceof Array ? 
item : [item]; if (i === logitsStartIndex) { if (resultToken === -1) throw new Error("Failed to sample next token"); if (resultToken == null) return; nextToken = resultToken; yieldRes.token = nextToken; if (probabilities != null) yieldRes.probabilities = reviveTokenProbabilities(probabilities); if (confidence != null) yieldRes.confidence = confidence; } else { if (resultToken === -1 || resultToken == null) break; const lastValidatedTokenOutput = i === logitsStartIndex + 1 ? nextToken : validatedTokens.at(-1)?.[1]; if (lastValidatedTokenOutput != null && lastValidatedTokenOutput === evalTokens[i]) { this._loadedTokenPredictions.push([evalTokens[i], [resultToken, probabilities, confidence]]); this._validatedTokenPredictions++; this._unusedTokenPredictions++; } else { const deleteSize = Math.min(evalTokens.length - i, this.context.contextSize); this._refutedTokenPredictions += deleteSize; const deleteStartIndex = this._nextTokenIndex - deleteSize; tokenPredictor.stop(true); await this._eraseContextTokenRanges([{ start: deleteStartIndex, end: this._nextTokenIndex }], { canResetTokenPredictor: false, canRemovePredictionTokens: false, skipLock: true }); break; // the assumption that this token will be generated was wrong } } } } if (nextToken == null) throw new Error("Failed to generated next token"); // the model finished generating text if (!yieldEogToken && this._context.model.isEogToken(nextToken)) break; } finally { evaluatorLock.dispose(); } const replacementToken = yield yieldRes; // set the tokens for the next evaluation if (replacementToken instanceof Array) evalTokens = replacementToken.slice(); else if (replacementToken != null) evalTokens = [replacementToken]; else evalTokens = [nextToken]; if (this._tokenPredictorOwner === tokenPredictorOwner) tokenPredictor.pushTokens(evalTokens); logitsArray = []; logitsStartIndex = evalTokens.length - 1; logitsArray[logitsStartIndex] = true; } } finally { void withLock([sampler, "sample"], sampler.asyncDispose); if (this._tokenPredictorOwner === tokenPredictorOwner) tokenPredictor.stop(); } } /** @internal */ async _abortTokenPredictor(skipClearingPredictionsFromState = false, skipLock = false) { this._tokenPredictor?.stop(); this._resetTokenPredictor = true; if (skipClearingPredictionsFromState) return; if (this._loadedTokenPredictions.length > 0) await this._eraseContextTokenRanges([{ start: this._nextTokenIndex - this._loadedTokenPredictions.length, end: this._nextTokenIndex }], { canResetTokenPredictor: true, canRemovePredictionTokens: true, skipLock }); } /** @internal */ _resolveSamplerConfig({ temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias }) { const repeatPenaltyTokens = repeatPenalty?.punishTokens instanceof Function ? repeatPenalty.punishTokens() : repeatPenalty?.punishTokens; const maxPunishTokens = Math.max(repeatPenalty?.maxPunishTokens ?? defaultMaxPunishTokens, repeatPenaltyTokens?.length ?? 0); const resolvedGrammarEvaluationState = grammarEvaluationState instanceof Function ? grammarEvaluationState() : grammarEvaluationState; if (resolvedGrammarEvaluationState != null && resolvedGrammarEvaluationState._llama !== this.model._llama) throw new Error("The LlamaGrammar used by passed to this function was created with a different Llama instance than the one used by this sequence's model. 
Make sure you use the same Llama instance for both the model and the grammar."); const { tokenBiasKeys, tokenBiasValues } = getTokenBiasesForAddon(tokenBias, this.model); return removeNullFields({ temperature, minP, topK, topP, seed: Math.max(0, Number.isFinite(seed) ? Math.floor(seed ?? (Date.now() / 1000)) : Math.floor(Date.now() / 1000)), repeatPenalty: repeatPenalty?.penalty, repeatPenaltyMaxTokens: maxPunishTokens, repeatPenaltyTokens: repeatPenaltyTokens != null ? Uint32Array.from(repeatPenaltyTokens) : undefined, repeatPenaltyPresencePenalty: repeatPenalty?.presencePenalty, repeatPenaltyFrequencyPenalty: repeatPenalty?.frequencyPenalty, tokenBiasKeys, tokenBiasValues, grammarEvaluationState: resolvedGrammarEvaluationState?._state }); } /** * The caller of this function has to wrap it with a lock to ensure this function doesn't run concurrently. * @internal */ async _decodeTokens(tokens, logits, evaluationPriority, tokenMeter, contextShiftOptions, logitDataMapper) { this._ensureNotDisposed(); const tokensLeftToDecode = tokens.slice(); const tokenLogitsLeftToDecode = logits.slice(); let currentTokenIndex = 0; const res = []; const normalizedLogitDataMapper = (batchLogitIndex, contextStateTokenIndex) => { return logitDataMapper(batchLogitIndex, currentTokenIndex + (contextStateTokenIndex - this._nextTokenIndex)); }; while (tokensLeftToDecode.length > 0) { this._ensureNotDisposed(); let freeSpace = this._context.contextSize - 1 - this._nextTokenIndex; if (freeSpace <= 0) { await this._freeUpSpaceForTokens(contextShiftOptions); freeSpace = this._context.contextSize - 1 - this._nextTokenIndex; if (freeSpace <= 0) throw new Error("Failed to free up space for new tokens"); } const tokensToDecode = tokensLeftToDecode.splice(0, freeSpace); const tokensLogits = tokenLogitsLeftToDecode.slice(0, tokensToDecode.length); const generatedLogits = await this._context._decodeTokens({ sequenceId: this._sequenceId, tokens: tokensToDecode, firstTokenSequenceIndex: this._nextTokenIndex, logits: tokensLogits, evaluationPriority, tokenMeter }, normalizedLogitDataMapper); for (const [index, value] of generatedLogits) res[currentTokenIndex + (index - this._nextTokenIndex)] = value; this._nextTokenIndex += tokensToDecode.length; currentTokenIndex += tokensToDecode.length; this._contextTokens = this._contextTokens.concat(tokensToDecode); } return res; } /** @internal */ async _freeUpSpaceForTokens(contextShiftOptions) { this._ensureNotDisposed(); const size = Math.min(this._nextTokenIndex, Math.max(1, contextShiftOptions.size instanceof Function ? 
await contextShiftOptions.size(this) : contextShiftOptions.size)); this._ensureNotDisposed(); if (contextShiftOptions.strategy === "eraseBeginning") { let eraseStartIndex = 0; if (this.model.tokens.bos != null && this._contextTokens[0] === this.model.tokens.bos) eraseStartIndex = 1; await this._eraseContextTokenRanges([{ start: eraseStartIndex, end: size + eraseStartIndex }], { skipLock: true }); } else { const ranges = await contextShiftOptions.strategy({ sequence: this, size }); if (ranges == null) throw new Error("Invalid delete ranges"); await this._eraseContextTokenRanges(ranges, { skipLock: true }); if (this._nextTokenIndex >= this._context.contextSize - 1) await this._eraseContextTokenRanges([{ start: 0, end: size }], { skipLock: true }); } } /** @internal */ _ensureNotDisposed() { if (this._disposed) throw new DisposedError(); } /** * We need this to make it impossible to manually create instances of this class outside the code of this library * @internal */ static _create({ sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(context.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor }) { return new LlamaContextSequence({ sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy }, tokenPredictor }); } } function getTokenBiasesForAddon(tokenBias, currentModel) { if (tokenBias == null) return { tokenBiasKeys: undefined, tokenBiasValues: undefined }; if (tokenBias instanceof Function) tokenBias = tokenBias(); if (tokenBias._tokenizer !== currentModel.tokenizer) throw new Error("This TokenBias instance was created with a different model than the one used by this context. " + "Make sure you use the model instance of the context sequence for the TokenBias you use it with."); const tokenBiasKeys = []; const tokenBiasValues = []; for (const [token, bias] of tokenBias._biases) { tokenBiasKeys.push(token); tokenBiasValues.push(bias); } if (tokenBiasKeys.length === 0 || tokenBiasValues.length === 0) { return { tokenBiasKeys: undefined, tokenBiasValues: undefined }; } return { tokenBiasKeys: Uint32Array.from(tokenBiasKeys), tokenBiasValues: Float32Array.from(tokenBiasValues) }; } function reviveTokenProbabilities(probabilities) { if (probabilities == null) return undefined; const res = new Map(); for (let i = 1; i < probabilities.length; i += 2) { const token = probabilities[i - 1]; const probability = probabilities[i]; res.set(token, probability); } return res; } function disposeContextIfReferenced(contextRef) { const context = contextRef.deref(); if (context != null) void context.dispose(); } function disposeContextSequenceIfReferenced(contextRef) { const context = contextRef.deref(); if (context != null) context.dispose(); } export function getDefaultContextBatchSize({ contextSize, sequences }) { return Math.min(contextSize * sequences, 512); } export function getDefaultContextSequences() { return 1; } const defaultFallbackContextSize = 4096; export function getDefaultModelContextSize({ trainContextSize }) { return trainContextSize ?? defaultFallbackContextSize; } //# sourceMappingURL=LlamaContext.js.map