Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤖 feat: Add titling to Google client #2983

Merged
merged 4 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ GOOGLE_KEY=user_provided
# Vertex AI
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro

# GOOGLE_TITLE_MODEL=gemini-pro

# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
# To use this restricted HarmBlockThreshold setting, you will need to either:
Expand Down
131 changes: 129 additions & 2 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ const {
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { formatMessage, createContextHandlers } = require('./prompts');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const {
formatMessage,
createContextHandlers,
titleInstruction,
truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');

const loc = 'us-central1';
const publisher = 'google';
Expand Down Expand Up @@ -591,12 +596,16 @@ class GoogleClient extends BaseClient {
createLLM(clientOptions) {
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
logger.info('Creating Google VertexAI client');
mungewrath marked this conversation as resolved.
Show resolved Hide resolved
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
logger.info('Creating Chat Google VertexAI client');
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
logger.info('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
logger.info('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
Expand All @@ -606,6 +615,7 @@ class GoogleClient extends BaseClient {
);
}

logger.info('Creating Chat Google Generative AI client');
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}

Expand Down Expand Up @@ -717,6 +727,123 @@ class GoogleClient extends BaseClient {
return reply;
}

/**
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
*/
async titleChatCompletion(_payload, options = {}) {
mungewrath marked this conversation as resolved.
Show resolved Hide resolved
const { abortController } = options;
const { parameters, instances } = _payload;
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};

let clientOptions = { ...parameters, maxRetries: 2 };

logger.info('Initialized title client options');

if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
};
}

if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}

if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}

const model = this.createLLM(clientOptions);

let reply = '';
const messages = this.isTextModel ? _payload.trim() : _messages;

const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.info('Identified titling model as 1.5 version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};

if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}

const safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;

const result = await client.generateContent(requestOptions);

reply = result.response?.text();

return reply;
} else {
logger.info('Beginning titling');
const safetySettings = _payload.safetySettings;

const titleResponse = await model.invoke(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: safetySettings,
});

reply = titleResponse.content;

return reply;
}
}

async titleConvo({ text, responseText = '' }) {
let title = 'New Chat';
const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;

let { prompt: payload } = await this.buildMessages([
{
text: `Please generate ${titleInstruction}

${convo}

||>Title:`,
isCreatedByUser: true,
author: this.userLabel,
},
]);

if (this.isVisionModel) {
logger.warn(
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
);

payload.parameters = { ...payload.parameters, model: settings.model.default };
}

try {
title = await this.titleChatCompletion(payload, {
abortController: new AbortController(),
onProgress: () => {},
});
} catch (e) {
logger.error('[GoogleClient] There was an issue generating the title', e);
}
logger.info(`Title response: ${title}`);
return title;
}

getSaveOptions() {
return {
promptPrefix: this.options.promptPrefix,
Expand Down
4 changes: 2 additions & 2 deletions api/server/routes/ask/google.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
Expand All @@ -20,7 +20,7 @@ router.post(
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
await AskController(req, res, next, initializeClient, addTitle);
},
);

Expand Down
44 changes: 44 additions & 0 deletions api/server/services/Endpoints/google/addTitle.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
const initializeClient = require('./initializeClient');

/**
 * Generates a title for a conversation after a response, caches it, and saves it
 * on the conversation.
 *
 * Does nothing when the `TITLE_CONVO` environment variable is disabled or when the
 * client was configured with `titleConvo: false`.
 *
 * @param {object} req - Request object; `req.user.id` identifies the owner.
 * @param {object} params
 * @param {string} params.text - The user's message.
 * @param {object} params.response - Response message; `conversationId` and `text` are read.
 * @param {object} params.client - The client that produced the response; its `options`
 *   seed the titling client.
 * @returns {Promise<void>}
 */
const addTitle = async (req, { text, response, client }) => {
  const { TITLE_CONVO = 'true' } = process.env ?? {};
  if (!isEnabled(TITLE_CONVO)) {
    return;
  }

  if (client.options.titleConvo === false) {
    return;
  }

  // TODO: Generalize model name
  const DEFAULT_TITLE_MODEL = 'gemini-pro';
  const { GOOGLE_TITLE_MODEL } = process.env ?? {};
  // `||` rather than `??`: an env var declared but left empty (`GOOGLE_TITLE_MODEL=`)
  // is the empty string, which must not be used as a model name.
  const titleModel = GOOGLE_TITLE_MODEL || DEFAULT_TITLE_MODEL;

  const titleEndpointOptions = {
    ...client.options,
    modelOptions: { ...client.options?.modelOptions, model: titleModel },
    attachments: undefined, // After a response, this is set to an empty array which results in an error during setOptions
  };

  const { client: titleClient } = await initializeClient({
    req,
    res: response,
    endpointOption: titleEndpointOptions,
  });

  const titleCache = getLogStores(CacheKeys.GEN_TITLE);
  const key = `${req.user.id}-${response.conversationId}`;

  const title = await titleClient.titleConvo({ text, responseText: response.text });
  await titleCache.set(key, title, 120000);
  await saveConvo(req.user.id, {
    conversationId: response.conversationId,
    title,
  });
};

module.exports = addTitle;
3 changes: 2 additions & 1 deletion api/server/services/Endpoints/google/index.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');

// Entry points exported for the Google endpoint service.
module.exports = {
// Generates, caches and saves a conversation title (see ./addTitle).
addTitle,
buildOptions,
initializeClient,
};