This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Added response_format capabilities to chat completion request #402

Open · wants to merge 5 commits into main
Changes from all commits
@NoArgsConstructor
public class ChatCompletionRequest {

    /**
     * The messages to generate chat completions for, in the <a
     * href="https://platform.openai.com/docs/guides/chat/introduction">chat format</a>.<br>
     * ...
     */
    List<ChatMessage> messages;

    /**
     * ID of the model to use.
     */
    String model;

    /**
     * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
     * decreasing the model's likelihood to repeat the same line verbatim.
     */
    @JsonProperty("frequency_penalty")
    Double frequencyPenalty;

    /**
     * <p>An object specifying the format that the model must output.</p>
     *
     * <p>Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.</p>
     *
     * <p><b>Important:</b> when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
     * Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting
     * in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if
     * finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.</p>
     */
    @JsonProperty("response_format")
    ResponseFormat responseFormat;

    /**
     * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
     * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
     * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
     * should result in a ban or exclusive selection of the relevant token.
     */
    @JsonProperty("logit_bias")
    Map<String, Integer> logitBias;

    /**
     * The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will
     * ...
     */
    @JsonProperty("max_tokens")
    Integer maxTokens;

    /**
     * How many chat completion chatCompletionChoices to generate for each input message.
     */
    Integer n;

    /**
     * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     * increasing the model's likelihood to talk about new topics.
     */
    @JsonProperty("presence_penalty")
    Double presencePenalty;

    /**
     * Up to 4 sequences where the API will stop generating further tokens.
     */
    List<String> stop;

    /**
     * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only <a
     * href="https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format">server-sent
     * events</a> as they become available, with the stream terminated by a data: [DONE] message.
     */
    Boolean stream;

    /**
     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
     * values like 0.2 will make it more focused and deterministic.<br>
     * We generally recommend altering this or top_p but not both.
     */
    Double temperature;

    /**
     * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
     * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>
     * We generally recommend altering this or temperature but not both.
     */
    @JsonProperty("top_p")
    Double topP;

    /**
     * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
     */
    String user;

    /**
     * Controls how the model responds to function calls, as specified in the <a href="https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call">OpenAI documentation</a>.
     */
    @JsonProperty("function_call")
    ChatCompletionRequestFunctionCall functionCall;

    /**
     * A list of the available functions.
     */
    List<?> functions;

    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    public static class ChatCompletionRequestFunctionCall {
        String name;

        public static ChatCompletionRequestFunctionCall of(String name) {
            return new ChatCompletionRequestFunctionCall(name);
        }

    }

    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    public static class ResponseFormat {
        String type;

        public static ResponseFormat of(String type) {
            return new ResponseFormat(type);
        }

    }

}
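
The Javadoc on response_format above carries the practical constraint: JSON mode only behaves well when the prompt itself also asks for JSON. A minimal caller-side sketch of the new builder option follows; it assumes an already-constructed OpenAiService named service, and the model name and prompt strings are illustrative rather than taken from this PR.

// Sketch only: enable JSON mode via the new response_format field.
// Assumes an existing OpenAiService `service`; model and prompts are illustrative.
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage(ChatMessageRole.SYSTEM.value(),
        "Reply with a single JSON object containing a \"summary\" field."));
messages.add(new ChatMessage(ChatMessageRole.USER.value(),
        "Summarize: the quick brown fox jumps over the lazy dog."));

ChatCompletionRequest request = ChatCompletionRequest.builder()
        .model("gpt-4-1106-preview")
        .messages(messages)
        .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
        .maxTokens(256)
        .build();

String jsonContent = service.createChatCompletion(request)
        .getChoices().get(0).getMessage().getContent();

The new test further down exercises the same flow against the live API and asserts that the returned content parses as JSON.
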
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.completion.chat.*;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;

    // ...

    static class Weather {
        // ...
    }

    enum WeatherUnit {
        CELSIUS, FAHRENHEIT
    }

    static class WeatherResponse {
        // ...
    }

    // ...

    void streamChatCompletionWithDynamicFunctions() {
        // ...
        assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
    }

    @Test
    void streamChatCompletionWithJsonResponseFormat() {
        final List<ChatMessage> messages = new ArrayList<>();

        // The system message is deliberately vague so as not to give too much direction about how the response should look.
        // The main point is that the chat completion should always contain JSON content.
        final ChatMessage systemMessage = new ChatMessage(
                ChatMessageRole.SYSTEM.value(),
                "You are a dog and will speak as such - but please do it in JSON."
        );

        messages.add(systemMessage);

        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                .builder()
                .model("gpt-4-1106-preview")
                .messages(messages)
                .n(1)
                .maxTokens(256)
                .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
                .build();

        ChatCompletionResult chatCompletion = service.createChatCompletion(chatCompletionRequest);

        ChatCompletionChoice chatCompletionChoice = chatCompletion.getChoices().get(0);
        String expectedJsonContent = chatCompletionChoice.getMessage().getContent();

        assertTrue(isValidJSON(expectedJsonContent), "Invalid JSON response:\n\n" + expectedJsonContent);
    }

    private boolean isValidJSON(String json) {
        try (final JsonParser parser = new ObjectMapper().createParser(json)) {
            while (parser.nextToken() != null) {
                // Just try to read all tokens in order to verify whether this is valid json.
            }
            return true;
        } catch (IOException ioe) {
            ioe.printStackTrace();
            return false;
        }
    }

}
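
The isValidJSON helper streams through every token, so it accepts any well-formed JSON value, including bare strings and numbers. A stricter check, sketched below as an alternative that is not part of the PR, could additionally require the root to be a JSON object, which is what the json_object response format is documented to produce; the helper name is hypothetical and the imports (JsonNode, ObjectMapper, IOException) are the same ones the test file already uses.

    // Alternative sketch (not in the PR): parse into a tree and require the root to be a JSON object.
    private boolean isJsonObjectResponse(String json) {
        try {
            JsonNode root = new ObjectMapper().readTree(json);
            return root != null && root.isObject();
        } catch (IOException ioe) {
            return false;
        }
    }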