Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for functions #60

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 141 additions & 2 deletions Sources/OpenAIKit/Chat/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,153 @@ extension Chat {
extension Chat.Choice: Codable {}

extension Chat {

public enum FunctionMode {
case none
case auto
case named(String)
}
}

extension Chat.FunctionMode: Codable {

private struct Named: Codable {
let name: String
}

public init(from decoder: Decoder) throws {
do {
let container = try decoder.singleValueContainer()
let singleStringValue = try container.decode(String.self)

switch singleStringValue {
case "none":
self = .none
case "auto":
self = .auto
default:
throw DecodingError.dataCorruptedError(
in: container,
debugDescription: "Invalid type"
)
}
} catch {
let parameterized = try Named(from: decoder)
self = .named(parameterized.name)
}
}

public func encode(to encoder: Encoder) throws {
switch self {
case .none:
var container = encoder.singleValueContainer()
try container.encode("none")
case .auto:
var container = encoder.singleValueContainer()
try container.encode("auto")
case .named(let string):
try Named(name: string).encode(to: encoder)
}
}
}

public protocol ChatFunction: Encodable {
associatedtype Parameters: Encodable

var name: String { get }
var description: String? { get }
var parameters: Parameters? { get }
}

public protocol ChatFunctionCall: Codable {
var name: String { get }
init(name: String)
}

public protocol ChatFunctionCallWithArguments: Codable {
associatedtype Arguments: Codable

var name: String { get }
var arguments: Arguments { get }

init(name: String, arguments: Arguments)
}

public struct UnstructuredChatFunctionCall: Codable {
let name: String
let arguments: String?

public func structured<T: ChatFunctionCall>(as callType: T.Type, decoder: JSONDecoder = JSONDecoder()) throws -> T {
callType.init(name: name)
}

public func structured<T: ChatFunctionCallWithArguments>(as callType: T.Type, decoder: JSONDecoder = JSONDecoder()) throws -> T {
if let arguments {
let data = Data(arguments.utf8)
let parsedArgs = try decoder.decode(callType.Arguments, from: data)
return callType.init(name: name, arguments: parsedArgs)
}
throw NSError(domain: "", code: -1)
}
}

public extension Chat {
typealias Function = ChatFunction
typealias UnstructuredFunctionCall = UnstructuredChatFunctionCall
typealias FunctionCall = ChatFunctionCall
typealias FunctionCallWithArguments = ChatFunctionCallWithArguments
}

extension Chat {

public enum Message {
case system(content: String)
case user(content: String)
case assistant(content: String)
case assistantWithCall(content: String?, call: UnstructuredFunctionCall)
case function(content: String?, name: String, call: UnstructuredFunctionCall?)
}
}

extension Chat.Message: Codable {

private enum CodingKeys: String, CodingKey {
case role
case content
case name
case functionCall
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let role = try container.decode(String.self, forKey: .role)
let content = try container.decode(String.self, forKey: .content)

switch role {
case "system":
let content = try container.decode(String.self, forKey: .content)
self = .system(content: content)
case "user":
let content = try container.decode(String.self, forKey: .content)
self = .user(content: content)
case "assistant":
self = .assistant(content: content)

if let call = try container.decodeIfPresent(UnstructuredChatFunctionCall.self, forKey: .functionCall) {
let content = try container.decodeIfPresent(String.self, forKey: .content)
self = .assistantWithCall(content: content, call: call)

} else {

let content = try container.decode(String.self, forKey: .content)
self = .assistant(content: content)
}

case "function":

let call = try container.decodeIfPresent(UnstructuredChatFunctionCall.self, forKey: .functionCall)

let content = try container.decodeIfPresent(String.self, forKey: .content)
let name = try container.decode(String.self, forKey: .name)
self = .function(content: content, name: name, call: call)
default:
throw DecodingError.dataCorruptedError(forKey: .role, in: container, debugDescription: "Invalid type")
}
Expand All @@ -66,23 +189,39 @@ extension Chat.Message: Codable {
case .assistant(let content):
try container.encode("assistant", forKey: .role)
try container.encode(content, forKey: .content)
case let .assistantWithCall(content, call):
try container.encode("assistant", forKey: .role)
try container.encode(call, forKey: .functionCall)
try container.encodeIfPresent(content, forKey: .content)
case let .function(content, name, call):
try container.encode("function", forKey: .role)
try container.encode(name, forKey: .name)
try container.encodeIfPresent(call, forKey: .functionCall)
try container.encodeIfPresent(content, forKey: .content)
}
}
}

extension Chat.Message {

public var content: String {
get {
switch self {
case .system(let content), .user(let content), .assistant(let content):
return content
case .assistantWithCall(let content, _), .function(let content, _, _):
return content ?? ""
}
}
set {
switch self {
case .system: self = .system(content: newValue)
case .user: self = .user(content: newValue)
case .assistant: self = .assistant(content: newValue)
case let .function(_, name, call):
self = .function(content: newValue, name: name, call: call)
case let .assistantWithCall(_, call):
self = .assistantWithCall(content: newValue, call: call)
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/OpenAIKit/Chat/ChatProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public struct ChatProvider {
public func create(
model: ModelID,
messages: [Chat.Message] = [],
functions: [any Chat.Function] = [],
functionMode: Chat.FunctionMode = .none,
temperature: Double = 1.0,
topP: Double = 1.0,
n: Int = 1,
Expand All @@ -31,6 +33,8 @@ public struct ChatProvider {
let request = try CreateChatRequest(
model: model.id,
messages: messages,
functions: functions,
functionMode: functionMode,
temperature: temperature,
topP: topP,
n: n,
Expand Down Expand Up @@ -63,6 +67,8 @@ public struct ChatProvider {
public func stream(
model: ModelID,
messages: [Chat.Message] = [],
functions: [any Chat.Function] = [],
functionMode: Chat.FunctionMode = .none,
temperature: Double = 1.0,
topP: Double = 1.0,
n: Int = 1,
Expand All @@ -77,6 +83,8 @@ public struct ChatProvider {
let request = try CreateChatRequest(
model: model.id,
messages: messages,
functions: functions,
functionMode: functionMode,
temperature: temperature,
topP: topP,
n: n,
Expand Down
16 changes: 16 additions & 0 deletions Sources/OpenAIKit/Chat/CreateChatRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ struct CreateChatRequest: Request {
init(
model: String,
messages: [Chat.Message],
functions: [any Chat.Function],
functionMode: Chat.FunctionMode,
temperature: Double,
topP: Double,
n: Int,
Expand All @@ -25,6 +27,8 @@ struct CreateChatRequest: Request {
let body = Body(
model: model,
messages: messages,
functions: functions,
functionCall: functionMode,
temperature: temperature,
topP: topP,
n: n,
Expand All @@ -45,6 +49,8 @@ extension CreateChatRequest {
struct Body: Encodable {
let model: String
let messages: [Chat.Message]
let functions: [any Chat.Function]
let functionCall: Chat.FunctionMode
let temperature: Double
let topP: Double
let n: Int
Expand All @@ -59,6 +65,8 @@ extension CreateChatRequest {
enum CodingKeys: CodingKey {
case model
case messages
case functions
case functionCall
case temperature
case topP
case n
Expand All @@ -78,6 +86,14 @@ extension CreateChatRequest {
if !messages.isEmpty {
try container.encode(messages, forKey: .messages)
}

if !functions.isEmpty {
var nestedContainer = container.nestedUnkeyedContainer(forKey: .functions)
try functions.forEach {
try nestedContainer.encode($0)
}
try container.encode(functionCall, forKey: .functionCall)
}

try container.encode(temperature, forKey: .temperature)
try container.encode(topP, forKey: .topP)
Expand Down
3 changes: 3 additions & 0 deletions Sources/OpenAIKit/Chat/FinishReason.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public enum FinishReason: String {

/// Omitted content due to a flag from our content filters
case contentFilter = "content_filter"

/// ...
case functionCall = "function_call"
}

extension FinishReason: Codable {}
1 change: 1 addition & 0 deletions Sources/OpenAIKit/Model/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ extension Model {
public enum GPT4: String, ModelID {
case gpt4 = "gpt-4"
case gpt40314 = "gpt-4-0314"
case gpt40613 = "gpt-4-0613"
case gpt4_32k = "gpt-4-32k"
case gpt4_32k0314 = "gpt-4-32k-0314"
}
Expand Down