212 lines
9.5 KiB
JavaScript
212 lines
9.5 KiB
JavaScript
import { SpecialToken, LlamaText, SpecialTokensText } from "../../utils/LlamaText.js";
|
|
import { ChatWrapper } from "../../ChatWrapper.js";
|
|
import { parseTextTemplate } from "../../utils/parseTextTemplate.js";
|
|
import { parseFunctionCallMessageTemplate } from "./utils/chatHistoryFunctionCallMessageTemplate.js";
|
|
import { templateSegmentOptionsToChatWrapperSettings } from "./utils/templateSegmentOptionsToChatWrapperSettings.js";
|
|
/**
 * A chat wrapper based on a simple template.
 * @example
 * <span v-pre>
 *
 * ```ts
 * import {TemplateChatWrapper} from "node-llama-cpp";
 *
 * const chatWrapper = new TemplateChatWrapper({
 *     template: "{{systemPrompt}}\n{{history}}model: {{completion}}\nuser: ",
 *     historyTemplate: {
 *         system: "system: {{message}}\n",
 *         user: "user: {{message}}\n",
 *         model: "model: {{message}}\n"
 *     },
 *     // functionCallMessageTemplate: { // optional
 *     //     call: "[[call: {{functionName}}({{functionParams}})]]",
 *     //     result: " [[result: {{functionCallResult}}]]"
 *     // },
 *     // segments: {
 *     //     thoughtTemplate: "<think>{{content}}</think>",
 *     //     reopenThoughtAfterFunctionCalls: true
 *     // }
 * });
 * ```
 *
 * </span>
 *
 * **<span v-pre>`{{systemPrompt}}`</span>** is optional and is replaced with the first system message
 * (when it does, that system message is not included in the history).
 *
 * **<span v-pre>`{{history}}`</span>** is replaced with the chat history.
 * Each message in the chat history is converted using the template passed to `historyTemplate` for the message role,
 * and all messages are joined together.
 *
 * **<span v-pre>`{{completion}}`</span>** is where the model's response is generated.
 * The text that comes after <span v-pre>`{{completion}}`</span> is used to determine when the model has finished generating the response,
 * and thus is mandatory.
 *
 * **`functionCallMessageTemplate`** is used to specify the format in which functions can be called by the model and
 * how their results are fed to the model after the function call.
 *
 * **`segments`** is used to specify the format of the segments generated by the model (like thought segments).
 */
|
|
export class TemplateChatWrapper extends ChatWrapper {
    wrapperName = "Template";
    settings;
    template;
    historyTemplate;
    joinAdjacentMessagesOfTheSameType;
    /** @internal */ _parsedChatTemplate;
    /** @internal */ _parsedChatHistoryTemplate;

    /**
     * @param {object} options
     * @param {string} options.template - the main template; must contain `{{history}}` and `{{completion}}`,
     * and may contain an optional `{{systemPrompt}}`
     * @param {{system: string, user: string, model: string}} options.historyTemplate - per-role message templates,
     * each containing `{{message}}`
     * @param options.functionCallMessageTemplate - optional format for model function calls and their results
     * @param {boolean} [options.joinAdjacentMessagesOfTheSameType=true] - when true, consecutive messages of a
     * compatible role are merged into a single history entry
     * @param options.segments - optional format of model-generated segments (like thought segments)
     */
    constructor({ template, historyTemplate, functionCallMessageTemplate, joinAdjacentMessagesOfTheSameType = true, segments }) {
        super();
        if (template == null || historyTemplate == null) {
            throw new Error("Template chat wrapper settings must have a template and historyTemplate.");
        }
        if (historyTemplate.system == null || historyTemplate.user == null || historyTemplate.model == null) {
            throw new Error("Template chat wrapper historyTemplate must have system, user, and model templates.");
        }
        this.template = template;
        this.historyTemplate = historyTemplate;
        this.joinAdjacentMessagesOfTheSameType = joinAdjacentMessagesOfTheSameType;
        // parse the templates once up front so `generateContextState` doesn't re-parse on every call
        this._parsedChatTemplate = parseChatTemplate(template);
        this._parsedChatHistoryTemplate = {
            system: parseChatHistoryTemplate(historyTemplate.system),
            user: parseChatHistoryTemplate(historyTemplate.user),
            model: parseChatHistoryTemplate(historyTemplate.model)
        };
        this.settings = {
            ...ChatWrapper.defaultSettings,
            functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions,
            segments: templateSegmentOptionsToChatWrapperSettings(segments)
        };
    }

    /**
     * Renders the chat history into the configured template.
     * Aggregates adjacent messages (per `joinAdjacentMessagesOfTheSameType`), wraps each aggregate with the
     * role templates, and places the last model response after the `{{completion}}` prefix.
     * @returns `contextText` plus the stop-generation triggers derived from the text after `{{completion}}`
     */
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        const transformedHistory = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });
        const aggregates = [];
        let pendingSystem = [];
        let pendingUser = [];
        let pendingModel = [];
        let focus = null;

        // push the pending texts as one aggregate entry (if any) and reset the accumulators
        const flush = () => {
            if (pendingSystem.length !== 0 || pendingUser.length !== 0 || pendingModel.length !== 0) {
                aggregates.push({
                    system: LlamaText.joinValues("\n\n", pendingSystem),
                    user: LlamaText.joinValues("\n\n", pendingUser),
                    model: LlamaText.joinValues("\n\n", pendingModel)
                });
            }
            pendingSystem = [];
            pendingUser = [];
            pendingModel = [];
        };

        for (const item of transformedHistory) {
            switch (item.type) {
                case "system":
                    if (!this.joinAdjacentMessagesOfTheSameType || focus !== "system")
                        flush();
                    focus = "system";
                    pendingSystem.push(LlamaText.fromJSON(item.text));
                    break;
                case "user":
                    // a user message may join a preceding system aggregate, so "system" focus doesn't flush here
                    if (!this.joinAdjacentMessagesOfTheSameType || (focus !== "system" && focus !== "user"))
                        flush();
                    focus = "user";
                    pendingUser.push(LlamaText(item.text));
                    break;
                case "model":
                    if (!this.joinAdjacentMessagesOfTheSameType)
                        flush();
                    focus = "model";
                    pendingModel.push(this.generateModelResponseText(item.response));
                    break;
                default:
                    void item;
                    break;
            }
        }
        flush();

        // wrap a message with the role's parsed prefix/suffix, optionally prepending extra template text
        const wrapWithRoleTemplate = (role, text, prefix) => {
            const { messagePrefix, messageSuffix } = this._parsedChatHistoryTemplate[role];
            return LlamaText([
                new SpecialTokensText((prefix ?? "") + messagePrefix),
                text,
                new SpecialTokensText(messageSuffix)
            ]);
        };

        const { systemPromptPrefix, historyPrefix, completionPrefix } = this._parsedChatTemplate;
        const contextText = LlamaText(aggregates.map(({ system, user, model }, index) => {
            const isFirst = index === 0;
            const isLast = index === aggregates.length - 1;

            // the first aggregate carries the template's system-prompt and history prefixes
            let systemPart;
            if (isFirst) {
                if (system.values.length === 0)
                    systemPart = new SpecialTokensText((systemPromptPrefix ?? "") + historyPrefix);
                else if (systemPromptPrefix != null)
                    systemPart = LlamaText([
                        new SpecialTokensText(systemPromptPrefix),
                        system,
                        new SpecialTokensText(historyPrefix)
                    ]);
                else
                    systemPart = wrapWithRoleTemplate("system", system, historyPrefix);
            }
            else {
                systemPart = system.values.length === 0
                    ? LlamaText([])
                    : wrapWithRoleTemplate("system", system);
            }

            const userPart = user.values.length === 0
                ? LlamaText([])
                : wrapWithRoleTemplate("user", user);

            // the last model response goes after the completion prefix instead of the model role template
            let modelPart;
            if (model.values.length === 0)
                modelPart = LlamaText([]);
            else if (!isLast)
                modelPart = wrapWithRoleTemplate("model", model);
            else
                modelPart = LlamaText([
                    new SpecialTokensText(completionPrefix),
                    model
                ]);

            // merge adjacent special-token texts into a single value
            const combined = LlamaText([systemPart, userPart, modelPart]);
            const mergedValues = [];
            for (const value of combined.values) {
                const previous = mergedValues[mergedValues.length - 1];
                if (value instanceof SpecialTokensText && previous instanceof SpecialTokensText)
                    mergedValues[mergedValues.length - 1] = new SpecialTokensText(previous.value + value.value);
                else
                    mergedValues.push(value);
            }
            return LlamaText(mergedValues);
        }));

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                // the completion suffix stops generation whether produced as plain text or as special tokens
                LlamaText(this._parsedChatTemplate.completionSuffix),
                LlamaText(new SpecialTokensText(this._parsedChatTemplate.completionSuffix))
            ]
        };
    }
}
|
|
/**
 * Parses the main chat template, splitting it around the `{{systemPrompt}}`,
 * `{{history}}` and `{{completion}}` placeholders.
 * @param {string} template - the full chat template text
 * @returns {{
 *     systemPromptPrefix: string | null,
 *     historyPrefix: string,
 *     completionPrefix: string,
 *     completionSuffix: string
 * }} the text fragments surrounding each placeholder
 * @throws {Error} when no text comes after `{{completion}}` - that trailing text
 * is required, since it's used as a stop-generation trigger
 */
function parseChatTemplate(template) {
    const parsedTemplate = parseTextTemplate(template, [{
        text: "{{systemPrompt}}",
        key: "systemPrompt",
        optional: true
    }, {
        text: "{{history}}",
        key: "history"
    }, {
        text: "{{completion}}",
        key: "completion"
    }]);
    // use strict equality; the suffix must be non-empty to detect the end of generation
    if (parsedTemplate.completion.suffix.length === 0)
        throw new Error('Chat template must have text after "{{completion}}"');
    return {
        // `{{systemPrompt}}` is optional, so its prefix may be absent
        systemPromptPrefix: parsedTemplate.systemPrompt?.prefix ?? null,
        historyPrefix: parsedTemplate.history.prefix,
        completionPrefix: parsedTemplate.completion.prefix,
        completionSuffix: parsedTemplate.completion.suffix
    };
}
|
|
/**
 * Parses a single role's history template, splitting it around the `{{message}}` placeholder.
 * @param {string} template - a role template, e.g. `"user: {{message}}\n"`
 * @returns {{messagePrefix: string, messageSuffix: string}} the text before and after `{{message}}`
 */
function parseChatHistoryTemplate(template) {
    const { message } = parseTextTemplate(template, [{
        text: "{{message}}",
        key: "message"
    }]);
    return {
        messagePrefix: message.prefix,
        messageSuffix: message.suffix
    };
}
|
|
//# sourceMappingURL=TemplateChatWrapper.js.map
|