Skip to content

Commit f41f5f7

Browse files
yahya-moumansabrennertlhunter
authored
(chore)APM: Refactor Bedrock Integration (#5137)
* refactor apm tracing * Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/index.js Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com> * Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/tracing.js Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com> * Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/tracing.js Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com> * Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/utils.js Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com> * CODEOWNERS * remove shouldSetChoiceId override * remove shouldSetChoiceId override * lint * Update packages/datadog-instrumentations/src/aws-sdk.js Co-authored-by: Thomas Hunter II <tlhunter@datadog.com> --------- Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com> Co-authored-by: Thomas Hunter II <tlhunter@datadog.com>
1 parent 30efc06 commit f41f5f7

File tree

6 files changed

+151
-78
lines changed

6 files changed

+151
-78
lines changed

CODEOWNERS

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
/packages/datadog-plugin-langchain/ @DataDog/ml-observability
6060
/packages/datadog-instrumentations/src/openai.js @DataDog/ml-observability
6161
/packages/datadog-instrumentations/src/langchain.js @DataDog/ml-observability
62+
/packages/datadog-plugin-aws-sdk/src/services/bedrockruntime @DataDog/ml-observability
63+
/packages/datadog-plugin-aws-sdk/test/bedrockruntime.spec.js @DataDog/ml-observability
6264

6365
# CI
6466
/.github/workflows/appsec.yml @DataDog/asm-js

packages/datadog-instrumentations/src/aws-sdk.js

+3-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ function getMessage (request, error, result) {
155155
}
156156

157157
function getChannelSuffix (name) {
158+
// some resource identifiers have spaces between ex: bedrock runtime
159+
name = name.replaceAll(' ', '')
158160
return [
159161
'cloudwatchlogs',
160162
'dynamodb',
@@ -168,7 +170,7 @@ function getChannelSuffix (name) {
168170
'sqs',
169171
'states',
170172
'stepfunctions',
171-
'bedrock runtime'
173+
'bedrockruntime'
172174
].includes(name)
173175
? name
174176
: 'default'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
const CompositePlugin = require('../../../../dd-trace/src/plugins/composite')
2+
const BedrockRuntimeTracing = require('./tracing')
3+
class BedrockRuntimePlugin extends CompositePlugin {
4+
static get id () {
5+
return 'bedrockruntime'
6+
}
7+
8+
static get plugins () {
9+
return {
10+
tracing: BedrockRuntimeTracing
11+
}
12+
}
13+
}
14+
module.exports = BedrockRuntimePlugin
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
'use strict'
2+
3+
const BaseAwsSdkPlugin = require('../../base')
4+
const { parseModelId, extractRequestParams, extractTextAndResponseReason } = require('./utils')
5+
6+
const enabledOperations = ['invokeModel']
7+
8+
class BedrockRuntime extends BaseAwsSdkPlugin {
9+
static get id () { return 'bedrockruntime' }
10+
11+
isEnabled (request) {
12+
const operation = request.operation
13+
if (!enabledOperations.includes(operation)) {
14+
return false
15+
}
16+
17+
return super.isEnabled(request)
18+
}
19+
20+
generateTags (params, operation, response) {
21+
const { modelProvider, modelName } = parseModelId(params.modelId)
22+
23+
const requestParams = extractRequestParams(params, modelProvider)
24+
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName)
25+
26+
const tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)
27+
28+
return tags
29+
}
30+
}
31+
32+
function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
33+
const tags = {}
34+
35+
// add request tags
36+
tags['resource.name'] = operation
37+
tags['aws.bedrock.request.model'] = modelName
38+
tags['aws.bedrock.request.model_provider'] = modelProvider.toLowerCase()
39+
tags['aws.bedrock.request.prompt'] = requestParams.prompt
40+
tags['aws.bedrock.request.temperature'] = requestParams.temperature
41+
tags['aws.bedrock.request.top_p'] = requestParams.topP
42+
tags['aws.bedrock.request.top_k'] = requestParams.topK
43+
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
44+
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
45+
tags['aws.bedrock.request.input_type'] = requestParams.inputType
46+
tags['aws.bedrock.request.truncate'] = requestParams.truncate
47+
tags['aws.bedrock.request.stream'] = requestParams.stream
48+
tags['aws.bedrock.request.n'] = requestParams.n
49+
50+
// add response tags
51+
if (modelName.includes('embed')) {
52+
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
53+
}
54+
if (textAndResponseReason.choiceId) {
55+
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
56+
}
57+
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
58+
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason
59+
60+
return tags
61+
}
62+
63+
module.exports = BedrockRuntime

packages/datadog-plugin-aws-sdk/src/services/bedrockruntime.js packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/utils.js

+67-75
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
'use strict'
22

3-
const BaseAwsSdkPlugin = require('../base')
4-
const log = require('../../../dd-trace/src/log')
3+
const log = require('../../../../dd-trace/src/log')
4+
5+
const MODEL_TYPE_IDENTIFIERS = [
6+
'foundation-model/',
7+
'custom-model/',
8+
'provisioned-model/',
9+
'imported-module/',
10+
'prompt/',
11+
'endpoint/',
12+
'inference-profile/',
13+
'default-prompt-router/'
14+
]
515

616
const PROVIDER = {
717
AI21: 'AI21',
@@ -13,44 +23,6 @@ const PROVIDER = {
1323
MISTRAL: 'MISTRAL'
1424
}
1525

16-
const enabledOperations = ['invokeModel']
17-
18-
class BedrockRuntime extends BaseAwsSdkPlugin {
19-
static get id () { return 'bedrock runtime' }
20-
21-
isEnabled (request) {
22-
const operation = request.operation
23-
if (!enabledOperations.includes(operation)) {
24-
return false
25-
}
26-
27-
return super.isEnabled(request)
28-
}
29-
30-
generateTags (params, operation, response) {
31-
let tags = {}
32-
let modelName = ''
33-
let modelProvider = ''
34-
const modelMeta = params.modelId.split('.')
35-
if (modelMeta.length === 2) {
36-
[modelProvider, modelName] = modelMeta
37-
modelProvider = modelProvider.toUpperCase()
38-
} else {
39-
[, modelProvider, modelName] = modelMeta
40-
modelProvider = modelProvider.toUpperCase()
41-
}
42-
43-
const shouldSetChoiceIds = modelProvider === PROVIDER.COHERE && !modelName.includes('embed')
44-
45-
const requestParams = extractRequestParams(params, modelProvider)
46-
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName, shouldSetChoiceIds)
47-
48-
tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)
49-
50-
return tags
51-
}
52-
}
53-
5426
class Generation {
5527
constructor ({ message = '', finishReason = '', choiceId = '' } = {}) {
5628
// stringify message as it could be a single generated message as well as a list of embeddings
@@ -65,18 +37,19 @@ class RequestParams {
6537
prompt = '',
6638
temperature = undefined,
6739
topP = undefined,
40+
topK = undefined,
6841
maxTokens = undefined,
6942
stopSequences = [],
7043
inputType = '',
7144
truncate = '',
7245
stream = '',
7346
n = undefined
7447
} = {}) {
75-
// TODO: set a truncation limit to prompt
7648
// stringify prompt as it could be a single prompt as well as a list of message objects
7749
this.prompt = typeof prompt === 'string' ? prompt : JSON.stringify(prompt) || ''
7850
this.temperature = temperature !== undefined ? temperature : undefined
7951
this.topP = topP !== undefined ? topP : undefined
52+
this.topK = topK !== undefined ? topK : undefined
8053
this.maxTokens = maxTokens !== undefined ? maxTokens : undefined
8154
this.stopSequences = stopSequences || []
8255
this.inputType = inputType || ''
@@ -86,11 +59,53 @@ class RequestParams {
8659
}
8760
}
8861

62+
function parseModelId (modelId) {
63+
// Best effort to extract the model provider and model name from the bedrock model ID.
64+
// modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
65+
// 1. Base model: "{model_provider}.{model_name}"
66+
// 2. Cross-region model: "{region}.{model_provider}.{model_name}"
67+
// 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
68+
// a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
69+
// b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
70+
// c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
71+
// d. Imported model: ARN prefix + "imported-module/{model-id}"
72+
// e. Prompt management: ARN prefix + "prompt/{prompt-id}"
73+
// f. Sagemaker: ARN prefix + "endpoint/{model-id}"
74+
// g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
75+
// h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
76+
// If model provider cannot be inferred from the modelId formatting, then default to "custom"
77+
modelId = modelId.toLowerCase()
78+
if (!modelId.startsWith('arn:aws')) {
79+
const modelMeta = modelId.split('.')
80+
if (modelMeta.length < 2) {
81+
return { modelProvider: 'custom', modelName: modelMeta[0] }
82+
}
83+
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
84+
}
85+
86+
for (const identifier of MODEL_TYPE_IDENTIFIERS) {
87+
if (!modelId.includes(identifier)) {
88+
continue
89+
}
90+
modelId = modelId.split(identifier).pop()
91+
if (['foundation-model/', 'custom-model/'].includes(identifier)) {
92+
const modelMeta = modelId.split('.')
93+
if (modelMeta.length < 2) {
94+
return { modelProvider: 'custom', modelName: modelId }
95+
}
96+
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
97+
}
98+
return { modelProvider: 'custom', modelName: modelId }
99+
}
100+
101+
return { modelProvider: 'custom', modelName: 'custom' }
102+
}
103+
89104
function extractRequestParams (params, provider) {
90105
const requestBody = JSON.parse(params.body)
91106
const modelId = params.modelId
92107

93-
switch (provider) {
108+
switch (provider.toUpperCase()) {
94109
case PROVIDER.AI21: {
95110
let userPrompt = requestBody.prompt
96111
if (modelId.includes('jamba')) {
@@ -176,11 +191,11 @@ function extractRequestParams (params, provider) {
176191
}
177192
}
178193

179-
function extractTextAndResponseReason (response, provider, modelName, shouldSetChoiceIds) {
194+
function extractTextAndResponseReason (response, provider, modelName) {
180195
const body = JSON.parse(Buffer.from(response.body).toString('utf8'))
181-
196+
const shouldSetChoiceIds = provider.toUpperCase() === PROVIDER.COHERE && !modelName.includes('embed')
182197
try {
183-
switch (provider) {
198+
switch (provider.toUpperCase()) {
184199
case PROVIDER.AI21: {
185200
if (modelName.includes('jamba')) {
186201
const generations = body.choices || []
@@ -262,34 +277,11 @@ function extractTextAndResponseReason (response, provider, modelName, shouldSetC
262277
return new Generation()
263278
}
264279

265-
function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
266-
const tags = {}
267-
268-
// add request tags
269-
tags['resource.name'] = operation
270-
tags['aws.bedrock.request.model'] = modelName
271-
tags['aws.bedrock.request.model_provider'] = modelProvider
272-
tags['aws.bedrock.request.prompt'] = requestParams.prompt
273-
tags['aws.bedrock.request.temperature'] = requestParams.temperature
274-
tags['aws.bedrock.request.top_p'] = requestParams.topP
275-
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
276-
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
277-
tags['aws.bedrock.request.input_type'] = requestParams.inputType
278-
tags['aws.bedrock.request.truncate'] = requestParams.truncate
279-
tags['aws.bedrock.request.stream'] = requestParams.stream
280-
tags['aws.bedrock.request.n'] = requestParams.n
281-
282-
// add response tags
283-
if (modelName.includes('embed')) {
284-
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
285-
}
286-
if (textAndResponseReason.choiceId) {
287-
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
288-
}
289-
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
290-
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason
291-
292-
return tags
280+
module.exports = {
281+
Generation,
282+
RequestParams,
283+
parseModelId,
284+
extractRequestParams,
285+
extractTextAndResponseReason,
286+
PROVIDER
293287
}
294-
295-
module.exports = BedrockRuntime

packages/datadog-plugin-aws-sdk/test/bedrock.spec.js packages/datadog-plugin-aws-sdk/test/bedrockruntime.spec.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ const PROVIDER = {
1616
}
1717

1818
describe('Plugin', () => {
19-
describe('aws-sdk (bedrock)', function () {
19+
describe('aws-sdk (bedrockruntime)', function () {
2020
setup()
2121

2222
withVersions('aws-sdk', ['@aws-sdk/smithy-client', 'aws-sdk'], '>=3', (version, moduleName) => {
@@ -217,7 +217,7 @@ describe('Plugin', () => {
217217
expect(span.meta).to.include({
218218
'aws.operation': 'invokeModel',
219219
'aws.bedrock.request.model': model.modelId.split('.')[1],
220-
'aws.bedrock.request.model_provider': model.provider,
220+
'aws.bedrock.request.model_provider': model.provider.toLowerCase(),
221221
'aws.bedrock.request.prompt': model.userPrompt
222222
})
223223
expect(span.metrics).to.include({

0 commit comments

Comments
 (0)