Skip to content

Commit

Permalink
Merge pull request #2 from ad-astra-video/rebase_tts
Browse files Browse the repository at this point in the history
Rebase tts
  • Loading branch information
pschroedl authored Oct 30, 2024
2 parents a5fc8e3 + 4e4d1dc commit f81f92d
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 53 deletions.
1 change: 1 addition & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ var (
"video/mp2t": ".ts",
"video/mp4": ".mp4",
"image/png": ".png",
"audio/wav": ".wav",
}
)

Expand Down
2 changes: 1 addition & 1 deletion core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type AI interface {
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.EncodedFileResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
4 changes: 4 additions & 0 deletions core/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,10 @@ func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTex
return &worker.ImageToTextResponse{Text: "Transcribed text"}, nil
}

// TextToSpeech returns a canned audio URL so tests can exercise the
// text-to-speech path without a real AI worker.
func (a *stubAIWorker) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
	resp := &worker.AudioResponse{
		Audio: worker.MediaURL{Url: "http://example.com/audio.wav"},
	}
	return resp, nil
}

// Warm is a no-op in the stub; a real AI worker would load/warm the model here.
func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
	return nil
}
Expand Down
116 changes: 90 additions & 26 deletions core/ai_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,68 +422,95 @@ func (n *LivepeerNode) saveLocalAIWorkerResults(ctx context.Context, results int
ext, _ := common.MimeTypeToExtension(contentType)
fileName := string(RandomManifestID()) + ext

imgRes, ok := results.(worker.ImageResponse)
if !ok {
// worker.TextResponse is JSON, no file save needed
return results, nil
}
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}

var buf bytes.Buffer
for i, image := range imgRes.Images {
buf.Reset()
err := worker.ReadImageB64DataUrl(image.Url, &buf)
if err != nil {
// try to load local file (image to video returns local file)
f, err := os.ReadFile(image.Url)
switch resp := results.(type) {
case worker.ImageResponse:
for i, image := range resp.Images {
buf.Reset()
err := worker.ReadImageB64DataUrl(image.Url, &buf)
if err != nil {
// try to load local file (image to video returns local file)
f, err := os.ReadFile(image.Url)
if err != nil {
return nil, err
}
buf = *bytes.NewBuffer(f)
}

osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
if err != nil {
return nil, err
}
buf = *bytes.NewBuffer(f)

resp.Images[i].Url = osUrl
}

results = resp
case worker.AudioResponse:
err := worker.ReadAudioB64DataUrl(resp.Audio.Url, &buf)
if err != nil {
return nil, err
}

osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
if err != nil {
return nil, err
}
resp.Audio.Url = osUrl

imgRes.Images[i].Url = osUrl
results = resp
}

return imgRes, nil
//no file response to save, response is text
return results, nil
}

func (n *LivepeerNode) saveRemoteAIWorkerResults(ctx context.Context, results *RemoteAIWorkerResult, requestID string) (*RemoteAIWorkerResult, error) {
if drivers.NodeStorage == nil {
return nil, fmt.Errorf("Missing local storage")
}

// save the file data to node and provide url for download
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}
// worker.ImageResponse used by ***-to-image and image-to-video require saving binary data for download
// worker.AudioResponse used to text-to-speech also requires saving binary data for download
// other pipelines do not require saving data since they are text responses
imgResp, isImg := results.Results.(worker.ImageResponse)
if isImg {
for idx := range imgResp.Images {
fileName := imgResp.Images[idx].Url
// save the file data to node and provide url for download
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}
switch resp := results.Results.(type) {
case worker.ImageResponse:
for idx := range resp.Images {
fileName := resp.Images[idx].Url
osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
if err != nil {
return nil, err
}

imgResp.Images[idx].Url = osUrl
resp.Images[idx].Url = osUrl
delete(results.Files, fileName)
}

// update results for url updates
results.Results = imgResp
results.Results = resp
case worker.AudioResponse:
fileName := resp.Audio.Url
osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
if err != nil {
return nil, err
}

resp.Audio.Url = osUrl
delete(results.Files, fileName)

results.Results = resp
}

// no file response to save, response is text
return results, nil
}

Expand Down Expand Up @@ -789,6 +816,39 @@ func (orch *orchestrator) ImageToText(ctx context.Context, requestID string, req
return res.Results, nil
}

// TextToSpeech routes a text-to-speech job either to the local AI worker (on a
// combined orchestrator/worker node) or to a remote AI worker, and saves the
// resulting audio so it can be downloaded by the caller.
func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) {
	// Combined orchestrator/AI worker nodes process the job locally.
	if orch.node.AIWorker != nil {
		workerResp, err := orch.node.TextToSpeech(ctx, req)
		if err != nil {
			clog.Errorf(ctx, "Error processing with local ai worker err=%q", err)
			if monitor.Enabled {
				monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
			}
			return nil, err
		}
		return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "audio/wav")
	}

	// Otherwise dispatch the job to a remote AI worker.
	res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "text-to-speech", *req.ModelId, "", AIJobRequestData{Request: req})
	if err != nil {
		return nil, err
	}

	res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID)
	if err != nil {
		clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
		if monitor.Enabled {
			monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
		}
		return nil, err
	}

	return res.Results, nil
}

// only used for sending work to remote AI worker
func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
node := orch.node
Expand Down Expand Up @@ -959,6 +1019,10 @@ func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequest
return n.AIWorker.LLM(ctx, req)
}

// TextToSpeech delegates a text-to-speech request to the node's attached AI worker.
func (n *LivepeerNode) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
	return n.AIWorker.TextToSpeech(ctx, req)
}

// transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline.
func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult {
ctx = clog.AddOrchSessionID(ctx, sessionID)
Expand Down
14 changes: 12 additions & 2 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ func (h *lphttp) ImageToText() http.Handler {
func (h *lphttp) TextToSpeech() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

Expand Down Expand Up @@ -440,11 +441,11 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
modelID = *v.ModelId

submitFn = func(ctx context.Context) (interface{}, error) {
return orch.TextToSpeech(ctx, v)
return orch.TextToSpeech(ctx, requestID, v)
}

// TTS pricing is typically in characters, including punctuation
words := utf8.RuneCountInString(*v.TextInput)
words := utf8.RuneCountInString(*v.Text)
outPixels = int64(1000 * words)

default:
Expand Down Expand Up @@ -762,6 +763,15 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core
wkrResult.Err = err
break
}
case "text-to-speech":
var parsedResp worker.AudioResponse
err := json.Unmarshal(body, &parsedResp)
if err != nil {
glog.Error("Error getting results json:", err)
wkrResult.Err = err
break
}
results = parsedResp
}

wkrResult.Results = results
Expand Down
40 changes: 33 additions & 7 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,18 +765,46 @@ func CalculateTextToSpeechLatencyScore(took time.Duration, inCharacters int64) f
return took.Seconds() / float64(inCharacters)
}

func processTextToSpeech(ctx context.Context, params aiRequestParams, req worker.GenTextToSpeechJSONRequestBody) (*worker.EncodedFileResponse, error) {
func processTextToSpeech(ctx context.Context, params aiRequestParams, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
}

audioResp := resp.(*worker.EncodedFileResponse)
audioResp, ok := resp.(*worker.AudioResponse)
if !ok {
return nil, errWrongFormat
}

var result []byte
var data bytes.Buffer
var name string
writer := bufio.NewWriter(&data)
err = worker.ReadAudioB64DataUrl(audioResp.Audio.Url, writer)
if err == nil {
// orchestrator sent bae64 encoded result in .Url
name = string(core.RandomManifestID()) + ".wav"
writer.Flush()
result = data.Bytes()
} else {
// orchestrator sent download url, get the data
name = filepath.Base(audioResp.Audio.Url)
result, err = core.DownloadData(ctx, audioResp.Audio.Url)
if err != nil {
return nil, err
}
}

newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(result), nil, 0)
if err != nil {
return nil, fmt.Errorf("error saving image to objectStore: %w", err)
}

audioResp.Audio.Url = newUrl
return audioResp, nil
}

func submitTextToSpeech(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenTextToSpeechJSONRequestBody) (*worker.EncodedFileResponse, error) {
func submitTextToSpeech(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
Expand All @@ -786,7 +814,7 @@ func submitTextToSpeech(ctx context.Context, params aiRequestParams, sess *AISes
return nil, err
}

textLength := len(*req.TextInput)
textLength := len(*req.Text)
clog.V(common.VERBOSE).Infof(ctx, "Submitting text-to-speech request with text length: %d", textLength)
inCharacters := int64(textLength)
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, inCharacters)
Expand Down Expand Up @@ -836,7 +864,7 @@ func submitTextToSpeech(ctx context.Context, params aiRequestParams, sess *AISes
monitor.AIRequestFinished(ctx, "text-to-speech", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}

var res worker.EncodedFileResponse
var res worker.AudioResponse
if err := json.Unmarshal(resp.Body, &res); err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "text-to-speech", *req.ModelId, sess.OrchestratorInfo)
Expand Down Expand Up @@ -1296,7 +1324,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitImageToText(ctx, params, sess, v)
}

case worker.GenTextToSpeechJSONRequestBody:
cap = core.Capability_TextToSpeech
modelID = defaultTextToSpeechModelID
Expand All @@ -1306,7 +1333,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitTextToSpeech(ctx, params, sess, v)
}

default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
Expand Down
Loading

0 comments on commit f81f92d

Please sign in to comment.