Created February 27, 2024 20:10
LangChain.js custom chat model for Deep Infra
import type { BaseChatModelParams } from "@langchain/core/language_models/chat_models";
import {
  type OpenAIClient,
  type ChatOpenAICallOptions,
  type OpenAIChatInput,
  type OpenAICoreRequestOptions,
  ChatOpenAI,
} from "@langchain/openai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

type DeepInfraUnsupportedArgs =
  | "frequencyPenalty"
  | "presencePenalty"
  | "logitBias"
  | "functions";

type DeepInfraUnsupportedCallOptions = "functions" | "function_call";

export interface ChatDeepInfraCallOptions
  extends Omit<ChatOpenAICallOptions, DeepInfraUnsupportedCallOptions> {
  response_format: {
    type: "json_object";
    schema: Record<string, unknown>;
  };
}

export interface ChatDeepInfraInput
  extends Omit<OpenAIChatInput, "openAIApiKey" | DeepInfraUnsupportedArgs>,
    BaseChatModelParams {
  /**
   * The Deep Infra API key to use for requests.
   * @default process.env.DEEP_INFRA_API_KEY
   */
  deepInfraApiKey?: string;
}

/**
 * Wrapper around the Deep Infra API for large language models fine-tuned for chat.
 *
 * The Deep Infra API is compatible with the OpenAI API, with some limitations. View the
 * full API reference at {@link https://deepinfra.com/docs/advanced/openai_api}.
 *
 * To use, you should have the `DEEP_INFRA_API_KEY` environment variable set.
 * @example
 * ```typescript
 * const model = new ChatDeepInfra({
 *   temperature: 0.9,
 *   deepInfraApiKey: process.env.DEEP_INFRA_API_KEY,
 * });
 *
 * const response = await model.invoke([new HumanMessage("Hello there!")]);
 * console.log(response);
 * ```
 */
export class ChatDeepInfra extends ChatOpenAI<ChatDeepInfraCallOptions> {
  static lc_name() {
    return "ChatDeepInfra";
  }

  _llmType() {
    return "deepInfra";
  }

  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      deepInfraApiKey: "DEEP_INFRA_API_KEY",
    };
  }

  lc_serializable = true;

  constructor(
    fields?: Partial<
      Omit<OpenAIChatInput, "openAIApiKey" | DeepInfraUnsupportedArgs>
    > &
      BaseChatModelParams & { deepInfraApiKey?: string }
  ) {
    const deepInfraApiKey =
      fields?.deepInfraApiKey || getEnvironmentVariable("DEEP_INFRA_API_KEY");
    if (!deepInfraApiKey) {
      throw new Error(
        `Deep Infra API key not found. Please set the DEEP_INFRA_API_KEY environment variable or pass the key via the "deepInfraApiKey" field.`
      );
    }
    super({
      ...fields,
      modelName: fields?.modelName || "mistralai/Mixtral-8x7B-Instruct-v0.1",
      openAIApiKey: deepInfraApiKey,
      configuration: {
        baseURL: "https://api.deepinfra.com/v1/openai/",
      },
    });
  }

  toJSON() {
    const result = super.toJSON();
    if (
      "kwargs" in result &&
      typeof result.kwargs === "object" &&
      result.kwargs != null
    ) {
      delete result.kwargs.openai_api_key;
      delete result.kwargs.configuration;
    }
    return result;
  }

  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>>;

  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<OpenAIClient.Chat.Completions.ChatCompletion>;

  /**
   * Calls the Deep Infra API with retry logic in case of failures.
   * @param request The request to send to the Deep Infra API.
   * @param options Optional configuration for the API call.
   * @returns The response from the Deep Infra API.
   */
  async completionWithRetry(
    request:
      | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming
      | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<
    | AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>
    | OpenAIClient.Chat.Completions.ChatCompletion
  > {
    // Strip parameters Deep Infra's OpenAI-compatible endpoint does not accept.
    delete request.frequency_penalty;
    delete request.presence_penalty;
    delete request.logit_bias;
    delete request.functions;

    // Both branches call the same parent method; narrowing on `stream` lets
    // TypeScript resolve the matching streaming/non-streaming overload.
    if (request.stream === true) {
      return super.completionWithRetry(request, options);
    }
    return super.completionWithRetry(request, options);
  }
}
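
A minimal usage sketch, assuming the class above is saved locally as `chat_deepinfra.ts` (a hypothetical path), `DEEP_INFRA_API_KEY` is set in the environment, and the code runs in an ESM context with top-level await. It only exercises behaviour inherited from `ChatOpenAI`, namely plain invocation and streaming:

```typescript
// Hypothetical local path; adjust to wherever the gist file is saved.
import { ChatDeepInfra } from "./chat_deepinfra";
import { HumanMessage } from "@langchain/core/messages";

const model = new ChatDeepInfra({
  modelName: "mistralai/Mixtral-8x7B-Instruct-v0.1",
  temperature: 0.7,
});

// Plain invocation, identical to any other LangChain chat model.
const response = await model.invoke([new HumanMessage("Hello there!")]);
console.log(response.content);

// Streaming is inherited from ChatOpenAI; the overridden completionWithRetry
// routes the request through the streaming overload when stream === true.
const stream = await model.stream([new HumanMessage("Tell me a short joke.")]);
for await (const chunk of stream) {
  process.stdout.write(String(chunk.content));
}
```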