Skip to content

Commit d23da08

Browse files
claudevdm and Claude
authored and committed
Deepcopy combine_fn in PreCombineFn and PostCombineFn. (apache#32598)
Co-authored-by: Claude <cvandermerwe@google.com>
1 parent 6622f7d commit d23da08

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py

+7 -4
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,18 @@ def test_combining_value_state(self):
5353

5454

5555
@parameterized_class([
56-
{'runner': direct_runner.BundleBasedDirectRunner},
57-
{'runner': fn_api_runner.FnApiRunner},
58-
]) # yapf: disable
56+
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'},
57+
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'},
58+
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'},
59+
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'},
60+
]) # yapf: disable
5961
class LocalCombineFnLifecycleTest(unittest.TestCase):
6062
def tearDown(self):
6163
CallSequenceEnforcingCombineFn.instances.clear()
6264

6365
def test_combine(self):
64-
run_combine(TestPipeline(runner=self.runner()))
66+
test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
67+
run_combine(TestPipeline(runner=self.runner(), options=test_options))
6568
self._assert_teardown_called()
6669

6770
def test_non_liftable_combine(self):

sdks/python/apache_beam/transforms/core.py

+25 -18
Original file line number | Diff line number | Diff line change
@@ -3147,33 +3147,40 @@ def process(self, element):
31473147
yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))
31483148

31493149
class PreCombineFn(CombineFn):
3150+
def __init__(self):
3151+
# Deepcopy of the combine_fn to avoid sharing state between lifted
3152+
# stages when using cloudpickle.
3153+
self._combine_fn_copy = copy.deepcopy(combine_fn)
3154+
self.setup = self._combine_fn_copy.setup
3155+
self.create_accumulator = self._combine_fn_copy.create_accumulator
3156+
self.add_input = self._combine_fn_copy.add_input
3157+
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
3158+
self.compact = self._combine_fn_copy.compact
3159+
self.teardown = self._combine_fn_copy.teardown
3160+
31503161
@staticmethod
31513162
def extract_output(accumulator):
31523163
# Boolean indicates this is an accumulator.
31533164
return (True, accumulator)
31543165

3155-
setup = combine_fn.setup
3156-
create_accumulator = combine_fn.create_accumulator
3157-
add_input = combine_fn.add_input
3158-
merge_accumulators = combine_fn.merge_accumulators
3159-
compact = combine_fn.compact
3160-
teardown = combine_fn.teardown
3161-
31623166
class PostCombineFn(CombineFn):
3163-
@staticmethod
3164-
def add_input(accumulator, element):
3167+
def __init__(self):
3168+
# Deepcopy of the combine_fn to avoid sharing state between lifted
3169+
# stages when using cloudpickle.
3170+
self._combine_fn_copy = copy.deepcopy(combine_fn)
3171+
self.setup = self._combine_fn_copy.setup
3172+
self.create_accumulator = self._combine_fn_copy.create_accumulator
3173+
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
3174+
self.compact = self._combine_fn_copy.compact
3175+
self.extract_output = self._combine_fn_copy.extract_output
3176+
self.teardown = self._combine_fn_copy.teardown
3177+
3178+
def add_input(self, accumulator, element):
31653179
is_accumulator, value = element
31663180
if is_accumulator:
3167-
return combine_fn.merge_accumulators([accumulator, value])
3181+
return self._combine_fn_copy.merge_accumulators([accumulator, value])
31683182
else:
3169-
return combine_fn.add_input(accumulator, value)
3170-
3171-
setup = combine_fn.setup
3172-
create_accumulator = combine_fn.create_accumulator
3173-
merge_accumulators = combine_fn.merge_accumulators
3174-
compact = combine_fn.compact
3175-
extract_output = combine_fn.extract_output
3176-
teardown = combine_fn.teardown
3183+
return self._combine_fn_copy.add_input(accumulator, value)
31773184

31783185
def StripNonce(nonce_key_value):
31793186
(_, key), value = nonce_key_value

0 commit comments

Comments
 (0)