Skip to content

Commit 25f089a

Browse files
sabrennerwatson
authored and committed
[MLOB-2098] feat(llmobs): record bedrock token counts (#5152)
* working version by patching deserializer * wip * cleanup * use tokens on response directly if available * make it run on all command types if available * make token extraction cleaner * test output * parseint headers * remove comment * rename channel * cleanup * simpler instance patching * fmt * Update packages/datadog-instrumentations/src/helpers/hooks.js * check subscribers
1 parent 2f816e4 commit 25f089a

File tree

5 files changed

+150
-35
lines changed

5 files changed

+150
-35
lines changed

packages/datadog-instrumentations/src/aws-sdk.js

+16
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ function wrapRequest (send) {
4040
}
4141
}
4242

43+
function wrapDeserialize (deserialize, channelSuffix) {
44+
const headersCh = channel(`apm:aws:response:deserialize:${channelSuffix}`)
45+
46+
return function (response) {
47+
if (headersCh.hasSubscribers) {
48+
headersCh.publish({ headers: response.headers })
49+
}
50+
51+
return deserialize.apply(this, arguments)
52+
}
53+
}
54+
4355
function wrapSmithySend (send) {
4456
return function (command, ...args) {
4557
const cb = args[args.length - 1]
@@ -61,6 +73,10 @@ function wrapSmithySend (send) {
6173
const responseStartChannel = channel(`apm:aws:response:start:${channelSuffix}`)
6274
const responseFinishChannel = channel(`apm:aws:response:finish:${channelSuffix}`)
6375

76+
if (typeof command.deserialize === 'function') {
77+
shimmer.wrap(command, 'deserialize', deserialize => wrapDeserialize(deserialize, channelSuffix))
78+
}
79+
6480
return innerAr.runInAsyncScope(() => {
6581
startCh.publish({
6682
serviceIdentifier,

packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/utils.js

+33-6
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,23 @@ const PROVIDER = {
2424
}
2525

2626
class Generation {
27-
constructor ({ message = '', finishReason = '', choiceId = '' } = {}) {
27+
constructor ({
28+
message = '',
29+
finishReason = '',
30+
choiceId = '',
31+
role,
32+
inputTokens,
33+
outputTokens
34+
} = {}) {
2835
// stringify message as it could be a single generated message as well as a list of embeddings
2936
this.message = typeof message === 'string' ? message : JSON.stringify(message) || ''
3037
this.finishReason = finishReason || ''
3138
this.choiceId = choiceId || undefined
39+
this.role = role
40+
this.usage = {
41+
inputTokens,
42+
outputTokens
43+
}
3244
}
3345
}
3446

@@ -202,9 +214,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
202214
if (generations.length > 0) {
203215
const generation = generations[0]
204216
return new Generation({
205-
message: generation.message,
217+
message: generation.message.content,
206218
finishReason: generation.finish_reason,
207-
choiceId: shouldSetChoiceIds ? generation.id : undefined
219+
choiceId: shouldSetChoiceIds ? generation.id : undefined,
220+
role: generation.message.role,
221+
inputTokens: body.usage?.prompt_tokens,
222+
outputTokens: body.usage?.completion_tokens
208223
})
209224
}
210225
}
@@ -214,7 +229,9 @@ function extractTextAndResponseReason (response, provider, modelName) {
214229
return new Generation({
215230
message: completion.data?.text,
216231
finishReason: completion?.finishReason,
217-
choiceId: shouldSetChoiceIds ? completion?.id : undefined
232+
choiceId: shouldSetChoiceIds ? completion?.id : undefined,
233+
inputTokens: body.usage?.prompt_tokens,
234+
outputTokens: body.usage?.completion_tokens
218235
})
219236
}
220237
return new Generation()
@@ -226,7 +243,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
226243
const results = body.results || []
227244
if (results.length > 0) {
228245
const result = results[0]
229-
return new Generation({ message: result.outputText, finishReason: result.completionReason })
246+
return new Generation({
247+
message: result.outputText,
248+
finishReason: result.completionReason,
249+
inputTokens: body.inputTextTokenCount,
250+
outputTokens: result.tokenCount
251+
})
230252
}
231253
break
232254
}
@@ -252,7 +274,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
252274
break
253275
}
254276
case PROVIDER.META: {
255-
return new Generation({ message: body.generation, finishReason: body.stop_reason })
277+
return new Generation({
278+
message: body.generation,
279+
finishReason: body.stop_reason,
280+
inputTokens: body.prompt_token_count,
281+
outputTokens: body.generation_token_count
282+
})
256283
}
257284
case PROVIDER.MISTRAL: {
258285
const mistralGenerations = body.outputs || []

packages/datadog-plugin-aws-sdk/test/fixtures/bedrockruntime.js

+37-19
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,22 @@ bedrockruntime.models = [
3232
},
3333
response: {
3434
inputTextTokenCount: 7,
35-
results: {
36-
inputTextTokenCount: 7,
37-
results: [
38-
{
39-
tokenCount: 35,
40-
outputText: '\n' +
41-
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
42-
'Paris is one of the most populous cities in the European Union. ',
43-
completionReason: 'FINISH'
44-
}
45-
]
46-
}
47-
}
35+
results: [{
36+
tokenCount: 35,
37+
outputText: '\n' +
38+
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
39+
'Paris is one of the most populous cities in the European Union. ',
40+
completionReason: 'FINISH'
41+
}]
42+
},
43+
usage: {
44+
inputTokens: 7,
45+
outputTokens: 35,
46+
totalTokens: 42
47+
},
48+
output: '\n' +
49+
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
50+
'Paris is one of the most populous cities in the European Union. '
4851
},
4952
{
5053
provider: PROVIDER.AI21,
@@ -79,7 +82,14 @@ bedrockruntime.models = [
7982
completion_tokens: 7,
8083
total_tokens: 17
8184
}
82-
}
85+
},
86+
usage: {
87+
inputTokens: 10,
88+
outputTokens: 7,
89+
totalTokens: 17
90+
},
91+
output: 'The capital of France is Paris.',
92+
outputRole: 'assistant'
8393
},
8494
{
8595
provider: PROVIDER.ANTHROPIC,
@@ -97,7 +107,8 @@ bedrockruntime.models = [
97107
completion: ' Paris is the capital of France.',
98108
stop_reason: 'stop_sequence',
99109
stop: '\n\nHuman:'
100-
}
110+
},
111+
output: ' Paris is the capital of France.'
101112
},
102113
{
103114
provider: PROVIDER.COHERE,
@@ -120,8 +131,8 @@ bedrockruntime.models = [
120131
}
121132
],
122133
prompt: 'What is the capital of France?'
123-
}
124-
134+
},
135+
output: ' The capital of France is Paris. \n'
125136
},
126137
{
127138
provider: PROVIDER.META,
@@ -138,7 +149,13 @@ bedrockruntime.models = [
138149
prompt_token_count: 10,
139150
generation_token_count: 7,
140151
stop_reason: 'stop'
141-
}
152+
},
153+
usage: {
154+
inputTokens: 10,
155+
outputTokens: 7,
156+
totalTokens: 17
157+
},
158+
output: '\n\nThe capital of France is Paris.'
142159
},
143160
{
144161
provider: PROVIDER.MISTRAL,
@@ -158,7 +175,8 @@ bedrockruntime.models = [
158175
stop_reason: 'stop'
159176
}
160177
]
161-
}
178+
},
179+
output: 'The capital of France is Paris.'
162180
}
163181
]
164182
bedrockruntime.modelConfig = {

packages/dd-trace/src/llmobs/plugins/bedrockruntime.js

+48-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ const {
88
parseModelId
99
} = require('../../../../datadog-plugin-aws-sdk/src/services/bedrockruntime/utils')
1010

11-
const enabledOperations = ['invokeModel']
11+
const ENABLED_OPERATIONS = ['invokeModel']
12+
13+
const requestIdsToTokens = {}
1214

1315
class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
1416
constructor () {
@@ -18,7 +20,7 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
1820
const request = response.request
1921
const operation = request.operation
2022
// avoids instrumenting other non supported runtime operations
21-
if (!enabledOperations.includes(operation)) {
23+
if (!ENABLED_OPERATIONS.includes(operation)) {
2224
return
2325
}
2426
const { modelProvider, modelName } = parseModelId(request.params.modelId)
@@ -30,6 +32,17 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
3032
const span = storage.getStore()?.span
3133
this.setLLMObsTags({ request, span, response, modelProvider, modelName })
3234
})
35+
36+
this.addSub('apm:aws:response:deserialize:bedrockruntime', ({ headers }) => {
37+
const requestId = headers['x-amzn-requestid']
38+
const inputTokenCount = headers['x-amzn-bedrock-input-token-count']
39+
const outputTokenCount = headers['x-amzn-bedrock-output-token-count']
40+
41+
requestIdsToTokens[requestId] = {
42+
inputTokensFromHeaders: inputTokenCount && parseInt(inputTokenCount),
43+
outputTokensFromHeaders: outputTokenCount && parseInt(outputTokenCount)
44+
}
45+
})
3346
}
3447

3548
setLLMObsTags ({ request, span, response, modelProvider, modelName }) {
@@ -52,7 +65,39 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
5265
})
5366

5467
// add I/O tags
55-
this._tagger.tagLLMIO(span, requestParams.prompt, textAndResponseReason.message)
68+
this._tagger.tagLLMIO(
69+
span,
70+
requestParams.prompt,
71+
[{ content: textAndResponseReason.message, role: textAndResponseReason.role }]
72+
)
73+
74+
// add token metrics
75+
const { inputTokens, outputTokens, totalTokens } = extractTokens({
76+
requestId: response.$metadata.requestId,
77+
usage: textAndResponseReason.usage
78+
})
79+
this._tagger.tagMetrics(span, {
80+
inputTokens,
81+
outputTokens,
82+
totalTokens
83+
})
84+
}
85+
}
86+
87+
function extractTokens ({ requestId, usage }) {
88+
const {
89+
inputTokensFromHeaders,
90+
outputTokensFromHeaders
91+
} = requestIdsToTokens[requestId] || {}
92+
delete requestIdsToTokens[requestId]
93+
94+
const inputTokens = usage.inputTokens || inputTokensFromHeaders || 0
95+
const outputTokens = usage.outputTokens || outputTokensFromHeaders || 0
96+
97+
return {
98+
inputTokens,
99+
outputTokens,
100+
totalTokens: inputTokens + outputTokens
56101
}
57102
}
58103

packages/dd-trace/test/llmobs/plugins/aws-sdk/bedrockruntime.spec.js

+16-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
const agent = require('../../../plugins/agent')
44

55
const nock = require('nock')
6-
const { expectedLLMObsLLMSpanEvent, deepEqualWithMockValues, MOCK_ANY } = require('../../util')
6+
const { expectedLLMObsLLMSpanEvent, deepEqualWithMockValues } = require('../../util')
77
const { models, modelConfig } = require('../../../../../datadog-plugin-aws-sdk/test/fixtures/bedrockruntime')
88
const chai = require('chai')
99
const LLMObsAgentProxySpanWriter = require('../../../../src/llmobs/writers/spans/agentProxy')
@@ -78,22 +78,31 @@ describe('Plugin', () => {
7878

7979
nock('http://127.0.0.1:4566')
8080
.post(`/model/${model.modelId}/invoke`)
81-
.reply(200, response)
81+
.reply(200, response, {
82+
'x-amzn-bedrock-input-token-count': 50,
83+
'x-amzn-bedrock-output-token-count': 70,
84+
'x-amzn-requestid': Date.now().toString()
85+
})
8286

8387
const command = new AWS.InvokeModelCommand(request)
8488

89+
const expectedOutput = { content: model.output }
90+
if (model.outputRole) expectedOutput.role = model.outputRole
91+
8592
agent.use(traces => {
8693
const span = traces[0][0]
8794
const spanEvent = LLMObsAgentProxySpanWriter.prototype.append.getCall(0).args[0]
8895
const expected = expectedLLMObsLLMSpanEvent({
8996
span,
9097
spanKind: 'llm',
9198
name: 'bedrock-runtime.command',
92-
inputMessages: [
93-
{ content: model.userPrompt }
94-
],
95-
outputMessages: MOCK_ANY,
96-
tokenMetrics: { input_tokens: 0, output_tokens: 0, total_tokens: 0 },
99+
inputMessages: [{ content: model.userPrompt }],
100+
outputMessages: [expectedOutput],
101+
tokenMetrics: {
102+
input_tokens: model.usage?.inputTokens ?? 50,
103+
output_tokens: model.usage?.outputTokens ?? 70,
104+
total_tokens: model.usage?.totalTokens ?? 120
105+
},
97106
modelName: model.modelId.split('.')[1].toLowerCase(),
98107
modelProvider: model.provider.toLowerCase(),
99108
metadata: {

0 commit comments

Comments (0)