Skip to content

Commit 430d3a9

Browse files
authored
Messages protocol implementation (#3843)
1 parent e6113da commit 430d3a9

14 files changed

+856
-583
lines changed

api/historyservice/v1/request_response.pb.go

+451-367
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/matchingservice/v1/request_response.pb.go

+247-163
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/util.go

+8-24
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ import (
3535

3636
"github.com/dgryski/go-farm"
3737
"github.com/gogo/protobuf/proto"
38-
commandpb "go.temporal.io/api/command/v1"
3938
commonpb "go.temporal.io/api/common/v1"
4039
enumspb "go.temporal.io/api/enums/v1"
41-
historypb "go.temporal.io/api/history/v1"
4240
"go.temporal.io/api/serviceerror"
4341
"go.temporal.io/api/workflowservice/v1"
4442

@@ -388,31 +386,16 @@ func WorkflowIDToHistoryShard(
388386
return int32(hash%uint32(numberOfShards)) + 1 // ShardID starts with 1
389387
}
390388

391-
// PrettyPrintHistory prints history in human-readable format
392-
func PrettyPrintHistory(history *historypb.History, header ...string) {
389+
func PrettyPrint[T proto.Message](msgs []T, header ...string) {
393390
var sb strings.Builder
394-
sb.WriteString("==========================================================================\n")
391+
_, _ = sb.WriteString("==========================================================================\n")
395392
for _, h := range header {
396-
sb.WriteString(h)
397-
sb.WriteString("\n")
393+
_, _ = sb.WriteString(h)
394+
_, _ = sb.WriteString("\n")
398395
}
399-
sb.WriteString("--------------------------------------------------------------------------\n")
400-
_ = proto.MarshalText(&sb, history)
401-
sb.WriteString("\n")
402-
fmt.Print(sb.String())
403-
}
404-
405-
// PrettyPrintCommands prints commands in human-readable format
406-
func PrettyPrintCommands(commands []*commandpb.Command, header ...string) {
407-
var sb strings.Builder
408-
sb.WriteString("==========================================================================\n")
409-
for _, h := range header {
410-
sb.WriteString(h)
411-
sb.WriteString("\n")
412-
}
413-
sb.WriteString("--------------------------------------------------------------------------\n")
414-
for _, command := range commands {
415-
_ = proto.MarshalText(&sb, command)
396+
_, _ = sb.WriteString("--------------------------------------------------------------------------\n")
397+
for _, m := range msgs {
398+
_ = proto.MarshalText(&sb, m)
416399
}
417400
fmt.Print(sb.String())
418401
}
@@ -465,6 +448,7 @@ func CreateMatchingPollWorkflowTaskQueueResponse(historyResponse *historyservice
465448
ScheduledTime: historyResponse.ScheduledTime,
466449
StartedTime: historyResponse.StartedTime,
467450
Queries: historyResponse.Queries,
451+
Messages: historyResponse.Messages,
468452
}
469453

470454
return matchingResp

proto/internal/temporal/server/api/historyservice/v1/request_response.proto

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import "temporal/api/taskqueue/v1/message.proto";
3434
import "temporal/api/enums/v1/workflow.proto";
3535
import "temporal/api/workflow/v1/message.proto";
3636
import "temporal/api/query/v1/message.proto";
37+
import "temporal/api/protocol/v1/message.proto";
3738
import "temporal/api/failure/v1/message.proto";
3839

3940
import "temporal/server/api/clock/v1/message.proto";
@@ -164,6 +165,7 @@ message RecordWorkflowTaskStartedResponse {
164165
google.protobuf.Timestamp started_time = 13 [(gogoproto.stdtime) = true];
165166
map<string, temporal.api.query.v1.WorkflowQuery> queries = 14;
166167
temporal.server.api.clock.v1.VectorClock clock = 15;
168+
repeated temporal.api.protocol.v1.Message messages = 16;
167169
}
168170

169171
message RecordActivityTaskStartedRequest {

proto/internal/temporal/server/api/matchingservice/v1/request_response.proto

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import "temporal/api/common/v1/message.proto";
3232
import "temporal/api/enums/v1/task_queue.proto";
3333
import "temporal/api/taskqueue/v1/message.proto";
3434
import "temporal/api/query/v1/message.proto";
35+
import "temporal/api/protocol/v1/message.proto";
3536

3637
import "temporal/server/api/clock/v1/message.proto";
3738
import "temporal/server/api/enums/v1/task.proto";
@@ -65,6 +66,7 @@ message PollWorkflowTaskQueueResponse {
6566
google.protobuf.Timestamp scheduled_time = 15 [(gogoproto.stdtime) = true];
6667
google.protobuf.Timestamp started_time = 16 [(gogoproto.stdtime) = true];
6768
map<string, temporal.api.query.v1.WorkflowQuery> queries = 17;
69+
repeated temporal.api.protocol.v1.Message messages = 18;
6870
}
6971

7072
message PollActivityTaskQueueRequest {

service/frontend/workflow_handler.go

+1
Original file line numberDiff line numberDiff line change
@@ -4433,6 +4433,7 @@ func (wh *WorkflowHandler) createPollWorkflowTaskQueueResponse(
44334433
ScheduledTime: matchingResp.ScheduledTime,
44344434
StartedTime: matchingResp.StartedTime,
44354435
Queries: matchingResp.Queries,
4436+
Messages: matchingResp.Messages,
44364437
}
44374438

44384439
return resp, nil

service/history/commandChecker.go

+9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
commandpb "go.temporal.io/api/command/v1"
3434
commonpb "go.temporal.io/api/common/v1"
3535
enumspb "go.temporal.io/api/enums/v1"
36+
protocolpb "go.temporal.io/api/protocol/v1"
3637
"go.temporal.io/api/serviceerror"
3738
taskqueuepb "go.temporal.io/api/taskqueue/v1"
3839

@@ -898,3 +899,11 @@ func (v *commandAttrValidator) commandTypes(
898899
}
899900
return result
900901
}
902+
903+
// TODO (alex-update): move to messageValidator.
904+
func (v *commandAttrValidator) validateMessages(
905+
_ []*protocolpb.Message,
906+
) error {
907+
908+
return nil
909+
}

service/history/workflowTaskHandler.go

+26
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
commonpb "go.temporal.io/api/common/v1"
3535
enumspb "go.temporal.io/api/enums/v1"
3636
failurepb "go.temporal.io/api/failure/v1"
37+
protocolpb "go.temporal.io/api/protocol/v1"
3738
"go.temporal.io/api/serviceerror"
3839
"go.temporal.io/api/workflowservice/v1"
3940

@@ -239,6 +240,31 @@ func (handler *workflowTaskHandlerImpl) handleCommand(ctx context.Context, comma
239240
}
240241
}
241242

243+
func (handler *workflowTaskHandlerImpl) handleMessages(
244+
ctx context.Context,
245+
messages []*protocolpb.Message,
246+
) error {
247+
if err := handler.attrValidator.validateMessages(
248+
messages,
249+
); err != nil {
250+
return err
251+
}
252+
253+
for _, message := range messages {
254+
err := handler.handleMessage(ctx, message)
255+
if err != nil || handler.stopProcessing {
256+
return err
257+
}
258+
}
259+
260+
return nil
261+
}
262+
263+
func (handler *workflowTaskHandlerImpl) handleMessage(_ context.Context, _ *protocolpb.Message) error {
264+
265+
return nil
266+
}
267+
242268
func (handler *workflowTaskHandlerImpl) handleCommandScheduleActivity(
243269
_ context.Context,
244270
attr *commandpb.ScheduleActivityTaskCommandAttributes,

service/history/workflowTaskHandlerCallbacks.go

+31-13
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,11 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
386386
metrics.OperationTag(metrics.HistoryRespondWorkflowTaskCompletedScope))
387387
}
388388

389-
workflowTaskHeartbeating := request.GetForceCreateNewWorkflowTask() && len(request.Commands) == 0
389+
workflowTaskHeartbeating := request.GetForceCreateNewWorkflowTask() && len(request.Commands) == 0 && len(request.Messages) == 0
390390
var workflowTaskHeartbeatTimeout bool
391391
var completedEvent *historypb.HistoryEvent
392+
var responseMutations []workflowTaskResponseMutation
393+
392394
if workflowTaskHeartbeating {
393395
namespace := namespaceEntry.Name()
394396
timeout := handler.config.WorkflowTaskHeartbeatTimeout(namespace.String())
@@ -423,11 +425,8 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
423425
wtFailedCause *workflowTaskFailedCause
424426
activityNotStartedCancelled bool
425427
newMutableState workflow.MutableState
426-
427-
hasUnhandledEvents bool
428-
responseMutations []workflowTaskResponseMutation
429428
)
430-
hasUnhandledEvents = ms.HasBufferedEvents()
429+
hasBufferedEvents := ms.HasBufferedEvents()
431430

432431
if request.StickyAttributes == nil || request.StickyAttributes.WorkerTaskQueue == nil {
433432
handler.metricsHandler.Counter(metrics.CompleteWorkflowTaskWithStickyDisabledCounter.GetMetricName()).Record(
@@ -481,7 +480,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
481480
handler.config,
482481
handler.shard,
483482
handler.searchAttributesMapper,
484-
hasUnhandledEvents,
483+
hasBufferedEvents,
485484
)
486485

487486
if responseMutations, err = workflowTaskHandler.handleCommands(
@@ -491,6 +490,13 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
491490
return nil, err
492491
}
493492

493+
if err = workflowTaskHandler.handleMessages(
494+
ctx,
495+
request.Messages,
496+
); err != nil {
497+
return nil, err
498+
}
499+
494500
// set the vars used by following logic
495501
// further refactor should also clean up the vars used below
496502
wtFailedCause = workflowTaskHandler.workflowTaskFailedCause
@@ -501,7 +507,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
501507

502508
newMutableState = workflowTaskHandler.newMutableState
503509

504-
hasUnhandledEvents = workflowTaskHandler.hasBufferedEvents
510+
hasBufferedEvents = workflowTaskHandler.hasBufferedEvents
505511
}
506512

507513
if wtFailedCause != nil {
@@ -522,7 +528,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
522528
if err != nil {
523529
return nil, err
524530
}
525-
hasUnhandledEvents = true
531+
hasBufferedEvents = true
526532
newMutableState = nil
527533

528534
if wtFailedCause.workflowFailure != nil {
@@ -532,24 +538,37 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
532538
if _, err := ms.AddFailWorkflowEvent(nextEventBatchId, enumspb.RETRY_STATE_NON_RETRYABLE_FAILURE, attributes, ""); err != nil {
533539
return nil, err
534540
}
535-
hasUnhandledEvents = false
541+
hasBufferedEvents = false
536542
}
537543
}
538544

539-
createNewWorkflowTask := ms.IsWorkflowExecutionRunning() && (hasUnhandledEvents || request.GetForceCreateNewWorkflowTask() || activityNotStartedCancelled)
545+
newWorkflowTaskType := enumsspb.WORKFLOW_TASK_TYPE_UNSPECIFIED
546+
if ms.IsWorkflowExecutionRunning() && (hasBufferedEvents || request.GetForceCreateNewWorkflowTask() || activityNotStartedCancelled) {
547+
newWorkflowTaskType = enumsspb.WORKFLOW_TASK_TYPE_NORMAL
548+
}
549+
createNewWorkflowTask := newWorkflowTaskType != enumsspb.WORKFLOW_TASK_TYPE_UNSPECIFIED
550+
540551
var newWorkflowTaskScheduledEventID int64
541552
if createNewWorkflowTask {
553+
// TODO (alex-update): Need to support case when ReturnNewWorkflowTask=false and WT.Type=Speculative.
554+
// In this case WT needs to be added directly to matching.
555+
// Current implementation will create normal WT.
542556
bypassTaskGeneration := request.GetReturnNewWorkflowTask() && wtFailedCause == nil
557+
if !bypassTaskGeneration {
558+
// If task generation can't be bypassed workflow task must be of Normal type because Speculative workflow task always skip task generation.
559+
newWorkflowTaskType = enumsspb.WORKFLOW_TASK_TYPE_NORMAL
560+
}
561+
543562
var newWorkflowTask *workflow.WorkflowTaskInfo
544563
var err error
545564
if workflowTaskHeartbeating && !workflowTaskHeartbeatTimeout {
546565
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEventAsHeartbeat(
547566
bypassTaskGeneration,
548567
currentWorkflowTask.OriginalScheduledTime,
549-
enumsspb.WORKFLOW_TASK_TYPE_NORMAL,
568+
enumsspb.WORKFLOW_TASK_TYPE_NORMAL, // Heartbeat workflow task is always of Normal type.
550569
)
551570
} else {
552-
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEvent(bypassTaskGeneration, enumsspb.WORKFLOW_TASK_TYPE_NORMAL)
571+
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEvent(bypassTaskGeneration, newWorkflowTaskType)
553572
}
554573
if err != nil {
555574
return nil, err
@@ -661,7 +680,6 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
661680
}
662681

663682
return resp, nil
664-
665683
}
666684

667685
func (handler *workflowTaskHandlerCallbacksImpl) verifyFirstWorkflowTaskScheduled(

tests/integration_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ package tests
2626

2727
import (
2828
"flag"
29+
"reflect"
2930
"testing"
3031
"time"
3132

33+
"github.com/gogo/protobuf/proto"
34+
"github.com/gogo/protobuf/types"
3235
"github.com/stretchr/testify/require"
3336
"github.com/stretchr/testify/suite"
3437
commonpb "go.temporal.io/api/common/v1"
@@ -79,3 +82,21 @@ func (s *integrationSuite) sendSignal(namespace string, execution *commonpb.Work
7982

8083
return err
8184
}
85+
86+
func unmarshalAny[T proto.Message](s *integrationSuite, a *types.Any) T {
87+
s.T().Helper()
88+
pb := new(T)
89+
ppb := reflect.ValueOf(pb).Elem()
90+
pbNew := reflect.New(reflect.TypeOf(pb).Elem().Elem())
91+
ppb.Set(pbNew)
92+
err := types.UnmarshalAny(a, *pb)
93+
s.NoError(err)
94+
return *pb
95+
}
96+
97+
func marshalAny(s *integrationSuite, pb proto.Message) *types.Any {
98+
s.T().Helper()
99+
a, err := types.MarshalAny(pb)
100+
s.NoError(err)
101+
return a
102+
}

tests/integrationbase.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func (s *IntegrationBase) randomizeStr(id string) string {
222222

223223
func (s *IntegrationBase) printWorkflowHistory(namespace string, execution *commonpb.WorkflowExecution) {
224224
events := s.getHistory(namespace, execution)
225-
common.PrettyPrintHistory(&historypb.History{Events: events})
225+
common.PrettyPrint(events)
226226
}
227227

228228
//lint:ignore U1000 used for debugging.

0 commit comments

Comments
 (0)