Last active
November 11, 2023 15:05
-
-
Save sebsto/544df0894b5e69b5e3cb504a004f772f to your computer and use it in GitHub Desktop.
Example of code to invoke the Cohere Embed model on Amazon Bedrock in the Swift programming language
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation
import ClientRuntime
import AWSBedrock
import AWSBedrockRuntime

// Reduce the verbosity of the AWS SDK by raising its log level to `.warning`.
// This runs before any client is created further down in the script.
SDKLoggingSystem.initialize(logLevel: .warning)
// Create a Bedrock client and list the foundation models available for one provider.
let provider = "cohere"
print("====== Models available for \(provider)")
let client = try BedrockClient(region: "us-east-1")
let input = ListFoundationModelsInput(byProvider: provider)
let output = try await client.listFoundationModels(input: input)
// `modelSummaries`, `modelName` and `modelId` are all optional on the SDK types, so fall back
// to an empty list / placeholder text instead of force-unwrapping and crashing the script.
let modelLines = (output.modelSummaries ?? []).map { summary in
    "\(summary.modelName ?? "<unnamed>") : \(summary.modelId ?? "<no id>")"
}
print(modelLines.joined(separator: "\n"))
print("======")
// Create a Bedrock runtime client and invoke an embedding model with a single document.
let modelId = "cohere.embed-english-v3"
let runtime = try BedrockRuntimeClient(region: "us-east-1")
let document =
    """
    This is a document that provides context about a business domain
    """
let payload = CohereEmbedDocument(texts: [document], inputType: .searchDocument)
let request = InvokeModelInput(body: try payload.encode(),
                               contentType: "application/json",
                               modelId: modelId)
do {
    let invokeModelOutput = try await runtime.invokeModel(input: request)
    // The response body is optional; report its absence instead of force-unwrapping,
    // which would trap rather than go through the `catch` below.
    if let body = invokeModelOutput.body {
        // To inspect the raw JSON while debugging:
        // print(String(decoding: body, as: UTF8.self))
        let cohereResponse = try CohereEmbedResponse(from: body)
        print(cohereResponse)
    } else {
        print("no response body returned for model \(modelId)")
    }
    print("======")
} catch {
    // Covers both service errors from `invokeModel` and JSON decoding errors.
    print(error)
}
/// Value of the input-type field sent with a Cohere Embed request.
///
/// Each case encodes as its raw string value.
/// See https://docs.cohere.com/reference/embed
enum CohereEmbedInputType: String, Encodable {
    case searchDocument = "search_document"
    case searchQuery = "search_query"
    // Misspelled case name kept so existing call sites keep compiling; prefer `classification`.
    case clasification = "classification"
    case clustering = "clustering"

    /// Correctly spelled alias for `clasification`; encodes as `"classification"`.
    static let classification: CohereEmbedInputType = .clasification
}
/// Value of the `truncate` field sent with a Cohere Embed request.
///
/// Each case encodes as its upper-case raw string value.
enum CohereEmbedTruncating: String, Encodable {
    case none = "NONE", start = "START", end = "END"
}
/// JSON request body for a Cohere Embed model invocation.
struct CohereEmbedDocument: Encodable {
    /// Texts to embed.
    let texts: [String]
    /// Input type sent with the request.
    let inputType: CohereEmbedInputType
    /// Truncation mode sent with the request.
    let truncate: CohereEmbedTruncating

    // The Cohere Embed request names this field `input_type` (snake case); the synthesized
    // key would otherwise be `inputType`. NOTE(review): confirm against the model's request schema.
    private enum CodingKeys: String, CodingKey {
        case texts
        case inputType = "input_type"
        case truncate
    }

    /// Creates a request body.
    /// - Parameters:
    ///   - texts: Texts to embed.
    ///   - inputType: Input type sent with the request.
    ///   - truncate: Truncation mode; defaults to `.end`, the value this type always used before.
    init(texts: [String],
         inputType: CohereEmbedInputType,
         truncate: CohereEmbedTruncating = .end) {
        self.texts = texts
        self.inputType = inputType
        self.truncate = truncate
    }

    /// Serializes this request to JSON.
    /// - Returns: The JSON-encoded request body.
    /// - Throws: Any error thrown by `JSONEncoder.encode(_:)`.
    func encode() throws -> Data {
        try JSONEncoder().encode(self)
    }
}
/// Decoded JSON response of a Cohere Embed model invocation.
struct CohereEmbedResponse: Decodable, CustomStringConvertible {
    /// One embedding vector per entry, as returned under the `embeddings` key.
    let embeddings: [[Double]]
    let id: String
    let texts: [String]

    /// Decodes a response from raw JSON data.
    /// - Parameter data: JSON containing at least the `embeddings`, `id` and `texts` keys.
    /// - Throws: Any `DecodingError` raised by `JSONDecoder`.
    init(from data: Data) throws {
        self = try JSONDecoder().decode(CohereEmbedResponse.self, from: data)
    }

    /// Short preview of the first embedding: up to five values with three decimals,
    /// followed by the total element count. Returns `"no embeddings"` when there are none.
    var description: String {
        guard let embedding = embeddings.first else {
            return "no embeddings"
        }
        let maxShown = 5
        let shown = embedding.prefix(maxShown).map { String(format: "%.3f", $0) }
        // Only signal truncation when values were actually left out; the ellipsis used to be
        // appended unconditionally, producing e.g. "[,...] (0 elements)".
        let ellipsis: [String] = embedding.count > maxShown ? ["..."] : []
        return "[" + (shown + ellipsis).joined(separator: ",") + "] (\(embedding.count) elements)"
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment