Skip to content

Commit 3828148

Browse files
authored
[AnomalyDetection] Support offline detectors (#34311)
* Add offline detctors. * Fix typo.
1 parent e0b9d03 commit 3828148

File tree

3 files changed

+348
-6
lines changed

3 files changed

+348
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Any
19+
from typing import Dict
20+
from typing import Optional
21+
22+
import apache_beam as beam
23+
from apache_beam.ml.anomaly.base import AnomalyDetector
24+
from apache_beam.ml.anomaly.specifiable import specifiable
25+
from apache_beam.ml.inference.base import KeyedModelHandler
26+
27+
28+
@specifiable
29+
class OfflineDetector(AnomalyDetector):
30+
"""A offline anomaly detector that uses a provided model handler for scoring.
31+
32+
Args:
33+
keyed_model_handler: The model handler to use for inference.
34+
Requires a `KeyModelHandler[Any, Row, float, Any]` instance.
35+
run_inference_args: Optional arguments to pass to RunInference
36+
**kwargs: Additional keyword arguments to pass to the base
37+
AnomalyDetector class.
38+
"""
39+
def __init__(
40+
self,
41+
keyed_model_handler: KeyedModelHandler[Any, beam.Row, float, Any],
42+
run_inference_args: Optional[Dict[str, Any]] = None,
43+
**kwargs):
44+
super().__init__(**kwargs)
45+
46+
# TODO: validate the model handler type
47+
self._keyed_model_handler = keyed_model_handler
48+
self._run_inference_args = run_inference_args or {}
49+
50+
# always override model_identifier with model_id from the detector
51+
self._run_inference_args["model_identifier"] = self._model_id
52+
53+
def learn_one(self, x: beam.Row) -> None:
54+
"""Not implemented since OfflineDetector invokes RunInference directly."""
55+
raise NotImplementedError
56+
57+
def score_one(self, x: beam.Row) -> Optional[float]:
58+
"""Not implemented since OfflineDetector invokes RunInference directly."""
59+
raise NotImplementedError

sdks/python/apache_beam/ml/anomaly/transforms.py

+92-6
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
import dataclasses
1919
import uuid
20+
from typing import Any
2021
from typing import Callable
2122
from typing import Dict
2223
from typing import Iterable
24+
from typing import List
2325
from typing import Optional
2426
from typing import Tuple
2527
from typing import TypeVar
@@ -33,8 +35,10 @@
3335
from apache_beam.ml.anomaly.base import AnomalyResult
3436
from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
3537
from apache_beam.ml.anomaly.base import ThresholdFn
38+
from apache_beam.ml.anomaly.detectors.offline import OfflineDetector
3639
from apache_beam.ml.anomaly.specifiable import Spec
3740
from apache_beam.ml.anomaly.specifiable import Specifiable
41+
from apache_beam.ml.inference.base import RunInference
3842
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
3943

4044
KeyT = TypeVar('KeyT')
@@ -97,9 +101,11 @@ def process(
97101
yield k1, (k2,
98102
AnomalyResult(
99103
example=data,
100-
predictions=[AnomalyPrediction(
101-
model_id=self._underlying._model_id,
102-
score=self.score_and_learn(data))]))
104+
predictions=[
105+
AnomalyPrediction(
106+
model_id=self._underlying._model_id,
107+
score=self.score_and_learn(data))
108+
]))
103109

104110
model_state.write(self._underlying)
105111

@@ -325,7 +331,8 @@ def expand(
325331
if self._aggregation_fn is None:
326332
# simply put predictions into an iterable (list)
327333
ret = (
328-
post_gbk | beam.MapTuple(
334+
post_gbk
335+
| beam.MapTuple(
329336
lambda k,
330337
v: (
331338
k[0],
@@ -353,7 +360,8 @@ def expand(
353360
# We use (original_key, temp_key) as the key for GroupByKey() so that
354361
# scores from multiple detectors per data point are grouped.
355362
ret = (
356-
post_gbk | beam.MapTuple(
363+
post_gbk
364+
| beam.MapTuple(
357365
lambda k,
358366
v,
359367
agg=aggregation_fn: (
@@ -406,6 +414,76 @@ def expand(
406414
return ret
407415

408416

417+
class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
418+
beam.PCollection[KeyedOutputT]]):
419+
"""Runs a offline anomaly detector on a PCollection of data.
420+
421+
This PTransform applies a `OfflineDetector` to the input data, handling
422+
custom input/output conversion and inference.
423+
424+
Args:
425+
offline_detector: The `OfflineDetector` to run.
426+
"""
427+
def __init__(self, offline_detector: OfflineDetector):
428+
self._offline_detector = offline_detector
429+
430+
def unnest_and_convert(
431+
self, nested: Tuple[Tuple[Any, Any], dict[str, List]]) -> KeyedOutputT:
432+
"""Unnests and converts the model output to AnomalyResult.
433+
434+
Args:
435+
nested: A tuple containing the combined key (origin key, temp key) and
436+
a dictionary of input and output from RunInference.
437+
438+
Returns:
439+
A tuple containing the original key and AnomalyResult.
440+
"""
441+
key, value_dict = nested
442+
score = value_dict['output'][0]
443+
result = AnomalyResult(
444+
example=value_dict['input'][0],
445+
predictions=[
446+
AnomalyPrediction(
447+
model_id=self._offline_detector._model_id, score=score)
448+
])
449+
return key[0], (key[1], result)
450+
451+
def expand(
452+
self,
453+
input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
454+
model_uuid = f"{self._offline_detector._model_id}:{uuid.uuid4().hex[:6]}"
455+
456+
# Call RunInference Transform with the keyed model handler
457+
run_inference = RunInference(
458+
self._offline_detector._keyed_model_handler,
459+
**self._offline_detector._run_inference_args)
460+
461+
# ((orig_key, temp_key), beam.Row)
462+
rekeyed_model_input = input | "Rekey" >> beam.Map(
463+
lambda x: ((x[0], x[1][0]), x[1][1]))
464+
465+
# ((orig_key, temp_key), float)
466+
rekeyed_model_output = (
467+
rekeyed_model_input
468+
| f"Call RunInference ({model_uuid})" >> run_inference)
469+
470+
# ((orig_key, temp_key), {'input':[row], 'output:[float]})
471+
rekeyed_cogbk = {
472+
'input': rekeyed_model_input, 'output': rekeyed_model_output
473+
} | beam.CoGroupByKey()
474+
475+
ret = (
476+
rekeyed_cogbk |
477+
"Unnest and convert model output" >> beam.Map(self.unnest_and_convert))
478+
479+
if self._offline_detector._threshold_criterion:
480+
ret = (
481+
ret | f"Run Threshold Criterion ({model_uuid})" >>
482+
RunThresholdCriterion(self._offline_detector._threshold_criterion))
483+
484+
return ret
485+
486+
409487
class RunEnsembleDetector(beam.PTransform[beam.PCollection[KeyedInputT],
410488
beam.PCollection[KeyedOutputT]]):
411489
"""Runs an ensemble of anomaly detectors on a PCollection of data.
@@ -432,8 +510,14 @@ def expand(
432510
for idx, detector in enumerate(self._ensemble_detector._sub_detectors):
433511
if isinstance(detector, EnsembleAnomalyDetector):
434512
results.append(
435-
input | f"Run Ensemble Detector at index {idx} ({model_uuid})" >>
513+
input
514+
| f"Run Ensemble Detector at index {idx} ({model_uuid})" >>
436515
RunEnsembleDetector(detector))
516+
elif isinstance(detector, OfflineDetector):
517+
results.append(
518+
input
519+
| f"Run Offline Detector at index {idx} ({model_uuid})" >>
520+
RunOfflineDetector(detector))
437521
else:
438522
results.append(
439523
input
@@ -518,6 +602,8 @@ def expand(
518602

519603
if isinstance(self._root_detector, EnsembleAnomalyDetector):
520604
keyed_output = (keyed_input | RunEnsembleDetector(self._root_detector))
605+
elif isinstance(self._root_detector, OfflineDetector):
606+
keyed_output = (keyed_input | RunOfflineDetector(self._root_detector))
521607
else:
522608
keyed_output = (keyed_input | RunOneDetector(self._root_detector))
523609

0 commit comments

Comments
 (0)