@@ -17,9 +17,11 @@
 
 import dataclasses
 import uuid
+from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import TypeVar
@@ -33,8 +35,10 @@
 from apache_beam.ml.anomaly.base import AnomalyResult
 from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
 from apache_beam.ml.anomaly.base import ThresholdFn
+from apache_beam.ml.anomaly.detectors.offline import OfflineDetector
 from apache_beam.ml.anomaly.specifiable import Spec
 from apache_beam.ml.anomaly.specifiable import Specifiable
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
 
 KeyT = TypeVar('KeyT')
@@ -97,9 +101,11 @@ def process(
     yield k1, (k2,
                AnomalyResult(
                    example=data,
-                   predictions=[AnomalyPrediction(
-                       model_id=self._underlying._model_id,
-                       score=self.score_and_learn(data))]))
+                   predictions=[
+                       AnomalyPrediction(
+                           model_id=self._underlying._model_id,
+                           score=self.score_and_learn(data))
+                   ]))
 
     model_state.write(self._underlying)
 
@@ -325,7 +331,8 @@ def expand(
     if self._aggregation_fn is None:
       # simply put predictions into an iterable (list)
       ret = (
-          post_gbk | beam.MapTuple(
+          post_gbk
+          | beam.MapTuple(
               lambda k,
               v: (
                   k[0],
@@ -353,7 +360,8 @@ def expand(
     # We use (original_key, temp_key) as the key for GroupByKey() so that
     # scores from multiple detectors per data point are grouped.
     ret = (
-        post_gbk | beam.MapTuple(
+        post_gbk
+        | beam.MapTuple(
             lambda k,
             v,
             agg=aggregation_fn: (
@@ -406,6 +414,76 @@ def expand(
     return ret
 
 
+class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
+                                         beam.PCollection[KeyedOutputT]]):
+  """Runs an offline anomaly detector on a PCollection of data.
+
+  This PTransform applies an `OfflineDetector` to the input data, handling
+  custom input/output conversion and inference.
+
+  Args:
+    offline_detector: The `OfflineDetector` to run.
+  """
+  def __init__(self, offline_detector: OfflineDetector):
+    self._offline_detector = offline_detector
+
+  def unnest_and_convert(
+      self, nested: Tuple[Tuple[Any, Any], dict[str, List]]) -> KeyedOutputT:
+    """Unnests and converts the model output to AnomalyResult.
+
+    Args:
+      nested: A tuple containing the combined key (original key, temp key)
+        and a dictionary of input and output from RunInference.
+
+    Returns:
+      A tuple containing the original key and AnomalyResult.
+    """
+    key, value_dict = nested
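+    # RunInference emits one output per input here, so each grouped list is
+    # expected to hold exactly one element per combined key; take it with [0].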
+    score = value_dict['output'][0]
+    result = AnomalyResult(
+        example=value_dict['input'][0],
+        predictions=[
+            AnomalyPrediction(
+                model_id=self._offline_detector._model_id, score=score)
+        ])
+    return key[0], (key[1], result)
+
+  def expand(
+      self,
+      input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
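+    # The short uuid suffix keeps the transform labels below unique when the
+    # same model id is used more than once in a pipeline.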
+    model_uuid = f"{self._offline_detector._model_id}:{uuid.uuid4().hex[:6]}"
+
+    # Construct the RunInference transform with the keyed model handler
+    run_inference = RunInference(
+        self._offline_detector._keyed_model_handler,
+        **self._offline_detector._run_inference_args)
+
+    # ((orig_key, temp_key), beam.Row)
+    rekeyed_model_input = input | "Rekey" >> beam.Map(
+        lambda x: ((x[0], x[1][0]), x[1][1]))
+
+    # ((orig_key, temp_key), float)
+    rekeyed_model_output = (
+        rekeyed_model_input
+        | f"Call RunInference ({model_uuid})" >> run_inference)
+
+    # ((orig_key, temp_key), {'input': [row], 'output': [float]})
+    rekeyed_cogbk = {
+        'input': rekeyed_model_input, 'output': rekeyed_model_output
+    } | beam.CoGroupByKey()
+
+    ret = (
+        rekeyed_cogbk |
+        "Unnest and convert model output" >> beam.Map(self.unnest_and_convert))
+
+    if self._offline_detector._threshold_criterion:
+      ret = (
+          ret | f"Run Threshold Criterion ({model_uuid})" >>
+          RunThresholdCriterion(self._offline_detector._threshold_criterion))
+
+    return ret
+
+
 class RunEnsembleDetector(beam.PTransform[beam.PCollection[KeyedInputT],
                                           beam.PCollection[KeyedOutputT]]):
   """Runs an ensemble of anomaly detectors on a PCollection of data.
@@ -432,8 +510,14 @@ def expand(
     for idx, detector in enumerate(self._ensemble_detector._sub_detectors):
       if isinstance(detector, EnsembleAnomalyDetector):
         results.append(
-            input | f"Run Ensemble Detector at index {idx} ({model_uuid})" >>
+            input
+            | f"Run Ensemble Detector at index {idx} ({model_uuid})" >>
             RunEnsembleDetector(detector))
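+      # Offline (pre-trained) detectors go through RunInference instead of
+      # the stateful score-and-learn path.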
+      elif isinstance(detector, OfflineDetector):
+        results.append(
+            input
+            | f"Run Offline Detector at index {idx} ({model_uuid})" >>
+            RunOfflineDetector(detector))
       else:
         results.append(
             input
@@ -518,6 +602,8 @@ def expand(
 
     if isinstance(self._root_detector, EnsembleAnomalyDetector):
       keyed_output = (keyed_input | RunEnsembleDetector(self._root_detector))
+    elif isinstance(self._root_detector, OfflineDetector):
+      keyed_output = (keyed_input | RunOfflineDetector(self._root_detector))
     else:
       keyed_output = (keyed_input | RunOneDetector(self._root_detector))
 
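A minimal usage sketch for the new transform (not part of the diff; the
OfflineDetector constructor argument shown is an illustrative placeholder,
and the real handler setup lives in apache_beam/ml/anomaly/detectors/offline.py):

  import apache_beam as beam

  # Hypothetical: an OfflineDetector wrapping some pre-built keyed model
  # handler; see OfflineDetector for the actual constructor signature.
  detector = OfflineDetector(keyed_model_handler=my_keyed_model_handler)

  with beam.Pipeline() as p:
    _ = (
        p
        # KeyedInputT per the Rekey lambda: (orig_key, (temp_key, beam.Row))
        | beam.Create([(0, (1, beam.Row(x=10.0)))])
        | RunOfflineDetector(detector)
        | beam.Map(print))  # emits (orig_key, (temp_key, AnomalyResult))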