Skip to content

Commit

Permalink
Use a common http client for LLM upstream calls (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
crspeller authored Aug 21, 2024
1 parent d1070ee commit b3772f4
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 17 deletions.
5 changes: 3 additions & 2 deletions server/ai/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package anthropic

import (
"fmt"
"net/http"

"github.com/mattermost/mattermost-plugin-ai/server/ai"
"github.com/mattermost/mattermost-plugin-ai/server/metrics"
Expand All @@ -16,8 +17,8 @@ type Anthropic struct {
metricsService metrics.LLMetrics
}

func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *Anthropic {
client := NewClient(llmService.APIKey)
func New(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *Anthropic {
client := NewClient(llmService.APIKey, httpClient)

return &Anthropic{
client: client,
Expand Down
6 changes: 3 additions & 3 deletions server/ai/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ type MessageStreamEvent struct {

type Client struct {
apiKey string
httpClient http.Client
httpClient *http.Client
}

func NewClient(apiKey string) *Client {
func NewClient(apiKey string, httpClient *http.Client) *Client {
return &Client{
apiKey: apiKey,
httpClient: http.Client{},
httpClient: httpClient,
}
}

Expand Down
5 changes: 3 additions & 2 deletions server/ai/asksage/asksage.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package asksage

import (
"net/http"
"strings"

"github.com/mattermost/mattermost-plugin-ai/server/ai"
Expand All @@ -14,8 +15,8 @@ type AskSage struct {
metric metrics.LLMetrics
}

func New(llmService ai.ServiceConfig, metric metrics.LLMetrics) *AskSage {
client := NewClient("")
func New(llmService ai.ServiceConfig, httpClient *http.Client, metric metrics.LLMetrics) *AskSage {
client := NewClient("", httpClient)
client.Login(GetTokenParams{
Email: llmService.Username,
Password: llmService.Password,
Expand Down
4 changes: 2 additions & 2 deletions server/ai/asksage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ type Persona struct {

type Dataset string

func NewClient(authToken string) *Client {
func NewClient(authToken string, httpClient *http.Client) *Client {
return &Client{
AuthToken: authToken,
HTTPClient: &http.Client{},
HTTPClient: httpClient,
}
}

Expand Down
7 changes: 5 additions & 2 deletions server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"image"
"image/png"
"io"
"net/http"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -38,12 +39,13 @@ const OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB

var ErrStreamingTimeout = errors.New("timeout streaming")

func NewCompatible(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI {
func NewCompatible(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI {
apiKey := llmService.APIKey
endpointURL := strings.TrimSuffix(llmService.APIURL, "/")
defaultModel := llmService.DefaultModel
config := openaiClient.DefaultConfig(apiKey)
config.BaseURL = endpointURL
config.HTTPClient = httpClient

parsedURL, err := url.Parse(endpointURL)
if err == nil && strings.HasSuffix(parsedURL.Host, "openai.azure.com") {
Expand All @@ -64,13 +66,14 @@ func NewCompatible(llmService ai.ServiceConfig, metricsService metrics.LLMetrics
}
}

func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI {
func New(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI {
defaultModel := llmService.DefaultModel
if defaultModel == "" {
defaultModel = openaiClient.GPT3Dot5Turbo
}
config := openaiClient.DefaultConfig(llmService.APIKey)
config.OrgID = llmService.OrgID
config.HTTPClient = httpClient

streamingTimeout := StreamingTimeoutDefault
if llmService.StreamingTimeoutSeconds > 0 {
Expand Down
19 changes: 13 additions & 6 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"sync"
"time"

"errors"

Expand All @@ -22,6 +23,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/pluginapi"
"github.com/mattermost/mattermost/server/public/shared/httpservice"
"github.com/nicksnyder/go-i18n/v2/i18n"
)

Expand Down Expand Up @@ -72,6 +74,8 @@ type Plugin struct {
bots []*Bot

i18n *i18n.Bundle

llmUpstreamHTTPClient *http.Client
}

func resolveffmpegPath() string {
Expand Down Expand Up @@ -100,6 +104,9 @@ func (p *Plugin) OnActivate() error {

p.i18n = i18nInit()

p.llmUpstreamHTTPClient = httpservice.MakeHTTPServicePlugin(p.API).MakeClient(true)
p.llmUpstreamHTTPClient.Timeout = time.Minute * 10 // LLM requests can be slow

if err := p.MigrateServicesToBots(); err != nil {
p.pluginAPI.Log.Error("failed to migrate services to bots", "error", err)
// Don't fail on migration errors
Expand Down Expand Up @@ -144,13 +151,13 @@ func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel {
var llm ai.LanguageModel
switch llmBotConfig.Service.Type {
case "openai":
llm = openai.New(llmBotConfig.Service, llmMetrics)
llm = openai.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "openaicompatible":
llm = openai.NewCompatible(llmBotConfig.Service, llmMetrics)
llm = openai.NewCompatible(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "anthropic":
llm = anthropic.New(llmBotConfig.Service, llmMetrics)
llm = anthropic.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "asksage":
llm = asksage.New(llmBotConfig.Service, llmMetrics)
llm = asksage.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
}

cfg := p.getConfiguration()
Expand All @@ -175,9 +182,9 @@ func (p *Plugin) getTranscribe() ai.Transcriber {
llmMetrics := p.metricsService.GetMetricsForAIService(botConfig.Name)
switch botConfig.Service.Type {
case "openai":
return openai.New(botConfig.Service, llmMetrics)
return openai.New(botConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "openaicompatible":
return openai.NewCompatible(botConfig.Service, llmMetrics)
return openai.NewCompatible(botConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
}
return nil
}
Expand Down

0 comments on commit b3772f4

Please sign in to comment.