1
1
'use strict'
2
2
3
- const BaseAwsSdkPlugin = require ( '../base' )
4
- const log = require ( '../../../dd-trace/src/log' )
3
+ const log = require ( '../../../../dd-trace/src/log' )
4
+
5
+ const MODEL_TYPE_IDENTIFIERS = [
6
+ 'foundation-model/' ,
7
+ 'custom-model/' ,
8
+ 'provisioned-model/' ,
9
+ 'imported-module/' ,
10
+ 'prompt/' ,
11
+ 'endpoint/' ,
12
+ 'inference-profile/' ,
13
+ 'default-prompt-router/'
14
+ ]
5
15
6
16
const PROVIDER = {
7
17
AI21 : 'AI21' ,
@@ -13,44 +23,6 @@ const PROVIDER = {
13
23
MISTRAL : 'MISTRAL'
14
24
}
15
25
16
- const enabledOperations = [ 'invokeModel' ]
17
-
18
- class BedrockRuntime extends BaseAwsSdkPlugin {
19
- static get id ( ) { return 'bedrock runtime' }
20
-
21
- isEnabled ( request ) {
22
- const operation = request . operation
23
- if ( ! enabledOperations . includes ( operation ) ) {
24
- return false
25
- }
26
-
27
- return super . isEnabled ( request )
28
- }
29
-
30
- generateTags ( params , operation , response ) {
31
- let tags = { }
32
- let modelName = ''
33
- let modelProvider = ''
34
- const modelMeta = params . modelId . split ( '.' )
35
- if ( modelMeta . length === 2 ) {
36
- [ modelProvider , modelName ] = modelMeta
37
- modelProvider = modelProvider . toUpperCase ( )
38
- } else {
39
- [ , modelProvider , modelName ] = modelMeta
40
- modelProvider = modelProvider . toUpperCase ( )
41
- }
42
-
43
- const shouldSetChoiceIds = modelProvider === PROVIDER . COHERE && ! modelName . includes ( 'embed' )
44
-
45
- const requestParams = extractRequestParams ( params , modelProvider )
46
- const textAndResponseReason = extractTextAndResponseReason ( response , modelProvider , modelName , shouldSetChoiceIds )
47
-
48
- tags = buildTagsFromParams ( requestParams , textAndResponseReason , modelProvider , modelName , operation )
49
-
50
- return tags
51
- }
52
- }
53
-
54
26
class Generation {
55
27
constructor ( { message = '' , finishReason = '' , choiceId = '' } = { } ) {
56
28
// stringify message as it could be a single generated message as well as a list of embeddings
@@ -65,18 +37,19 @@ class RequestParams {
65
37
prompt = '' ,
66
38
temperature = undefined ,
67
39
topP = undefined ,
40
+ topK = undefined ,
68
41
maxTokens = undefined ,
69
42
stopSequences = [ ] ,
70
43
inputType = '' ,
71
44
truncate = '' ,
72
45
stream = '' ,
73
46
n = undefined
74
47
} = { } ) {
75
- // TODO: set a truncation limit to prompt
76
48
// stringify prompt as it could be a single prompt as well as a list of message objects
77
49
this . prompt = typeof prompt === 'string' ? prompt : JSON . stringify ( prompt ) || ''
78
50
this . temperature = temperature !== undefined ? temperature : undefined
79
51
this . topP = topP !== undefined ? topP : undefined
52
+ this . topK = topK !== undefined ? topK : undefined
80
53
this . maxTokens = maxTokens !== undefined ? maxTokens : undefined
81
54
this . stopSequences = stopSequences || [ ]
82
55
this . inputType = inputType || ''
@@ -86,11 +59,53 @@ class RequestParams {
86
59
}
87
60
}
88
61
62
+ function parseModelId ( modelId ) {
63
+ // Best effort to extract the model provider and model name from the bedrock model ID.
64
+ // modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
65
+ // 1. Base model: "{model_provider}.{model_name}"
66
+ // 2. Cross-region model: "{region}.{model_provider}.{model_name}"
67
+ // 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
68
+ // a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
69
+ // b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
70
+ // c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
71
+ // d. Imported model: ARN prefix + "imported-module/{model-id}"
72
+ // e. Prompt management: ARN prefix + "prompt/{prompt-id}"
73
+ // f. Sagemaker: ARN prefix + "endpoint/{model-id}"
74
+ // g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
75
+ // h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
76
+ // If model provider cannot be inferred from the modelId formatting, then default to "custom"
77
+ modelId = modelId . toLowerCase ( )
78
+ if ( ! modelId . startsWith ( 'arn:aws' ) ) {
79
+ const modelMeta = modelId . split ( '.' )
80
+ if ( modelMeta . length < 2 ) {
81
+ return { modelProvider : 'custom' , modelName : modelMeta [ 0 ] }
82
+ }
83
+ return { modelProvider : modelMeta [ modelMeta . length - 2 ] , modelName : modelMeta [ modelMeta . length - 1 ] }
84
+ }
85
+
86
+ for ( const identifier of MODEL_TYPE_IDENTIFIERS ) {
87
+ if ( ! modelId . includes ( identifier ) ) {
88
+ continue
89
+ }
90
+ modelId = modelId . split ( identifier ) . pop ( )
91
+ if ( [ 'foundation-model/' , 'custom-model/' ] . includes ( identifier ) ) {
92
+ const modelMeta = modelId . split ( '.' )
93
+ if ( modelMeta . length < 2 ) {
94
+ return { modelProvider : 'custom' , modelName : modelId }
95
+ }
96
+ return { modelProvider : modelMeta [ modelMeta . length - 2 ] , modelName : modelMeta [ modelMeta . length - 1 ] }
97
+ }
98
+ return { modelProvider : 'custom' , modelName : modelId }
99
+ }
100
+
101
+ return { modelProvider : 'custom' , modelName : 'custom' }
102
+ }
103
+
89
104
function extractRequestParams ( params , provider ) {
90
105
const requestBody = JSON . parse ( params . body )
91
106
const modelId = params . modelId
92
107
93
- switch ( provider ) {
108
+ switch ( provider . toUpperCase ( ) ) {
94
109
case PROVIDER . AI21 : {
95
110
let userPrompt = requestBody . prompt
96
111
if ( modelId . includes ( 'jamba' ) ) {
@@ -176,11 +191,11 @@ function extractRequestParams (params, provider) {
176
191
}
177
192
}
178
193
179
- function extractTextAndResponseReason ( response , provider , modelName , shouldSetChoiceIds ) {
194
+ function extractTextAndResponseReason ( response , provider , modelName ) {
180
195
const body = JSON . parse ( Buffer . from ( response . body ) . toString ( 'utf8' ) )
181
-
196
+ const shouldSetChoiceIds = provider . toUpperCase ( ) === PROVIDER . COHERE && ! modelName . includes ( 'embed' )
182
197
try {
183
- switch ( provider ) {
198
+ switch ( provider . toUpperCase ( ) ) {
184
199
case PROVIDER . AI21 : {
185
200
if ( modelName . includes ( 'jamba' ) ) {
186
201
const generations = body . choices || [ ]
@@ -262,34 +277,11 @@ function extractTextAndResponseReason (response, provider, modelName, shouldSetC
262
277
return new Generation ( )
263
278
}
264
279
265
- function buildTagsFromParams ( requestParams , textAndResponseReason , modelProvider , modelName , operation ) {
266
- const tags = { }
267
-
268
- // add request tags
269
- tags [ 'resource.name' ] = operation
270
- tags [ 'aws.bedrock.request.model' ] = modelName
271
- tags [ 'aws.bedrock.request.model_provider' ] = modelProvider
272
- tags [ 'aws.bedrock.request.prompt' ] = requestParams . prompt
273
- tags [ 'aws.bedrock.request.temperature' ] = requestParams . temperature
274
- tags [ 'aws.bedrock.request.top_p' ] = requestParams . topP
275
- tags [ 'aws.bedrock.request.max_tokens' ] = requestParams . maxTokens
276
- tags [ 'aws.bedrock.request.stop_sequences' ] = requestParams . stopSequences
277
- tags [ 'aws.bedrock.request.input_type' ] = requestParams . inputType
278
- tags [ 'aws.bedrock.request.truncate' ] = requestParams . truncate
279
- tags [ 'aws.bedrock.request.stream' ] = requestParams . stream
280
- tags [ 'aws.bedrock.request.n' ] = requestParams . n
281
-
282
- // add response tags
283
- if ( modelName . includes ( 'embed' ) ) {
284
- tags [ 'aws.bedrock.response.embedding_length' ] = textAndResponseReason . message . length
285
- }
286
- if ( textAndResponseReason . choiceId ) {
287
- tags [ 'aws.bedrock.response.choices.id' ] = textAndResponseReason . choiceId
288
- }
289
- tags [ 'aws.bedrock.response.choices.text' ] = textAndResponseReason . message
290
- tags [ 'aws.bedrock.response.choices.finish_reason' ] = textAndResponseReason . finishReason
291
-
292
- return tags
280
+ module . exports = {
281
+ Generation,
282
+ RequestParams,
283
+ parseModelId,
284
+ extractRequestParams,
285
+ extractTextAndResponseReason,
286
+ PROVIDER
293
287
}
294
-
295
- module . exports = BedrockRuntime
0 commit comments