Skip to content

Commit a23489b

Browse files
authored
Set namespace on API if not present (#3953)
1 parent 8917ecc commit a23489b

File tree

3 files changed

+169
-8
lines changed

3 files changed

+169
-8
lines changed

common/rpc/interceptor/namespace_validator.go

+66-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ import (
4040
)
4141

4242
type (
43-
// NamespaceValidatorInterceptor contains LengthValidationIntercept and StateValidationIntercept
43+
TaskTokenGetter interface {
44+
GetTaskToken() []byte
45+
}
46+
47+
// NamespaceValidatorInterceptor contains NamespaceValidateIntercept and StateValidationIntercept
4448
NamespaceValidatorInterceptor struct {
4549
namespaceRegistry namespace.Registry
4650
tokenSerializer common.TaskTokenSerializer
@@ -71,7 +75,7 @@ var (
7175
)
7276

7377
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).StateValidationIntercept
74-
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).LengthValidationIntercept
78+
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).NamespaceValidateIntercept
7579

7680
func NewNamespaceValidatorInterceptor(
7781
namespaceRegistry namespace.Registry,
@@ -86,12 +90,16 @@ func NewNamespaceValidatorInterceptor(
8690
}
8791
}
8892

89-
func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
93+
func (ni *NamespaceValidatorInterceptor) NamespaceValidateIntercept(
9094
ctx context.Context,
9195
req interface{},
9296
info *grpc.UnaryServerInfo,
9397
handler grpc.UnaryHandler,
9498
) (interface{}, error) {
99+
err := ni.setNamespaceIfNotPresent(req)
100+
if err != nil {
101+
return nil, err
102+
}
95103
reqWithNamespace, hasNamespace := req.(NamespaceNameGetter)
96104
if hasNamespace {
97105
namespaceName := namespace.Name(reqWithNamespace.GetNamespace())
@@ -103,6 +111,60 @@ func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
103111
return handler(ctx, req)
104112
}
105113

114+
func (ni *NamespaceValidatorInterceptor) setNamespaceIfNotPresent(
115+
req interface{},
116+
) error {
117+
switch request := req.(type) {
118+
case NamespaceNameGetter:
119+
if request.GetNamespace() == "" {
120+
namespaceEntry, err := ni.extractNamespaceFromTaskToken(req)
121+
if err != nil {
122+
return err
123+
}
124+
ni.setNamespace(namespaceEntry, req)
125+
}
126+
return nil
127+
default:
128+
return nil
129+
}
130+
}
131+
132+
func (ni *NamespaceValidatorInterceptor) setNamespace(
133+
namespaceEntry *namespace.Namespace,
134+
req interface{},
135+
) {
136+
switch request := req.(type) {
137+
case *workflowservice.RespondQueryTaskCompletedRequest:
138+
if request.Namespace == "" {
139+
request.Namespace = namespaceEntry.Name().String()
140+
}
141+
case *workflowservice.RespondWorkflowTaskCompletedRequest:
142+
if request.Namespace == "" {
143+
request.Namespace = namespaceEntry.Name().String()
144+
}
145+
case *workflowservice.RespondWorkflowTaskFailedRequest:
146+
if request.Namespace == "" {
147+
request.Namespace = namespaceEntry.Name().String()
148+
}
149+
case *workflowservice.RecordActivityTaskHeartbeatRequest:
150+
if request.Namespace == "" {
151+
request.Namespace = namespaceEntry.Name().String()
152+
}
153+
case *workflowservice.RespondActivityTaskCanceledRequest:
154+
if request.Namespace == "" {
155+
request.Namespace = namespaceEntry.Name().String()
156+
}
157+
case *workflowservice.RespondActivityTaskCompletedRequest:
158+
if request.Namespace == "" {
159+
request.Namespace = namespaceEntry.Name().String()
160+
}
161+
case *workflowservice.RespondActivityTaskFailedRequest:
162+
if request.Namespace == "" {
163+
request.Namespace = namespaceEntry.Name().String()
164+
}
165+
}
166+
}
167+
106168
// StateValidationIntercept validates:
107169
// 1. Namespace is specified in task token if there is a `task_token` field.
108170
// 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field.
@@ -202,7 +264,7 @@ func (ni *NamespaceValidatorInterceptor) extractNamespaceFromRequest(req interfa
202264
}
203265

204266
func (ni *NamespaceValidatorInterceptor) extractNamespaceFromTaskToken(req interface{}) (*namespace.Namespace, error) {
205-
reqWithTaskToken, hasTaskToken := req.(interface{ GetTaskToken() []byte })
267+
reqWithTaskToken, hasTaskToken := req.(TaskTokenGetter)
206268
if !hasTaskToken {
207269
return nil, nil
208270
}

common/rpc/interceptor/namespace_validator_test.go

+102-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"testing"
3131

3232
"github.com/golang/mock/gomock"
33+
"github.com/google/uuid"
3334
"github.com/stretchr/testify/require"
3435
"github.com/stretchr/testify/suite"
3536
enumspb "go.temporal.io/api/enums/v1"
@@ -684,18 +685,44 @@ func (s *namespaceValidatorSuite) Test_Intercept_SearchAttributeRequests() {
684685
}
685686
}
686687

687-
func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {
688+
func (s *namespaceValidatorSuite) Test_NamespaceValidateIntercept() {
688689
nvi := NewNamespaceValidatorInterceptor(
689690
s.mockRegistry,
690691
dynamicconfig.GetBoolPropertyFn(false),
691692
dynamicconfig.GetIntPropertyFn(10))
692693
serverInfo := &grpc.UnaryServerInfo{
693694
FullMethod: "/temporal/random",
694695
}
696+
requestNamespace := namespace.FromPersistentState(
697+
&persistence.GetNamespaceResponse{
698+
Namespace: &persistencespb.NamespaceDetail{
699+
Config: &persistencespb.NamespaceConfig{},
700+
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
701+
Info: &persistencespb.NamespaceInfo{
702+
Id: uuid.New().String(),
703+
Name: "namespace",
704+
State: enumspb.NAMESPACE_STATE_REGISTERED,
705+
},
706+
},
707+
})
708+
requestNamespaceTooLong := namespace.FromPersistentState(
709+
&persistence.GetNamespaceResponse{
710+
Namespace: &persistencespb.NamespaceDetail{
711+
Config: &persistencespb.NamespaceConfig{},
712+
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
713+
Info: &persistencespb.NamespaceInfo{
714+
Id: uuid.New().String(),
715+
Name: "namespaceTooLong",
716+
State: enumspb.NAMESPACE_STATE_REGISTERED,
717+
},
718+
},
719+
})
720+
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespace")).Return(requestNamespace, nil).AnyTimes()
721+
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespaceTooLong")).Return(requestNamespaceTooLong, nil).AnyTimes()
695722

696723
req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespace"}
697724
handlerCalled := false
698-
_, err := nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
725+
_, err := nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
699726
handlerCalled = true
700727
return &workflowservice.StartWorkflowExecutionResponse{}, nil
701728
})
@@ -704,10 +731,82 @@ func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {
704731

705732
req = &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespaceTooLong"}
706733
handlerCalled = false
707-
_, err = nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
734+
_, err = nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
708735
handlerCalled = true
709736
return &workflowservice.StartWorkflowExecutionResponse{}, nil
710737
})
711738
s.False(handlerCalled)
712739
s.Error(err)
713740
}
741+
742+
func (s *namespaceValidatorSuite) TestSetNamespace() {
743+
namespaceRequestName := uuid.New().String()
744+
namespaceEntryName := uuid.New().String()
745+
namespaceEntry := namespace.FromPersistentState(
746+
&persistence.GetNamespaceResponse{
747+
Namespace: &persistencespb.NamespaceDetail{
748+
Config: &persistencespb.NamespaceConfig{},
749+
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
750+
Info: &persistencespb.NamespaceInfo{
751+
Id: uuid.New().String(),
752+
Name: namespaceEntryName,
753+
State: enumspb.NAMESPACE_STATE_REGISTERED,
754+
},
755+
},
756+
})
757+
758+
nvi := NewNamespaceValidatorInterceptor(
759+
s.mockRegistry,
760+
dynamicconfig.GetBoolPropertyFn(false),
761+
dynamicconfig.GetIntPropertyFn(10),
762+
)
763+
764+
queryReq := &workflowservice.RespondQueryTaskCompletedRequest{}
765+
nvi.setNamespace(namespaceEntry, queryReq)
766+
s.Equal(namespaceEntryName, queryReq.Namespace)
767+
queryReq.Namespace = namespaceRequestName
768+
nvi.setNamespace(namespaceEntry, queryReq)
769+
s.Equal(namespaceRequestName, queryReq.Namespace)
770+
771+
completeWorkflowTaskReq := &workflowservice.RespondWorkflowTaskCompletedRequest{}
772+
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
773+
s.Equal(namespaceEntryName, completeWorkflowTaskReq.Namespace)
774+
completeWorkflowTaskReq.Namespace = namespaceRequestName
775+
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
776+
s.Equal(namespaceRequestName, completeWorkflowTaskReq.Namespace)
777+
778+
failWorkflowTaskReq := &workflowservice.RespondWorkflowTaskFailedRequest{}
779+
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
780+
s.Equal(namespaceEntryName, failWorkflowTaskReq.Namespace)
781+
failWorkflowTaskReq.Namespace = namespaceRequestName
782+
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
783+
s.Equal(namespaceRequestName, failWorkflowTaskReq.Namespace)
784+
785+
heartbeatActivityTaskReq := &workflowservice.RecordActivityTaskHeartbeatRequest{}
786+
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
787+
s.Equal(namespaceEntryName, heartbeatActivityTaskReq.Namespace)
788+
heartbeatActivityTaskReq.Namespace = namespaceRequestName
789+
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
790+
s.Equal(namespaceRequestName, heartbeatActivityTaskReq.Namespace)
791+
792+
cancelActivityTaskReq := &workflowservice.RespondActivityTaskCanceledRequest{}
793+
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
794+
s.Equal(namespaceEntryName, cancelActivityTaskReq.Namespace)
795+
cancelActivityTaskReq.Namespace = namespaceRequestName
796+
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
797+
s.Equal(namespaceRequestName, cancelActivityTaskReq.Namespace)
798+
799+
completeActivityTaskReq := &workflowservice.RespondActivityTaskCompletedRequest{}
800+
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
801+
s.Equal(namespaceEntryName, completeActivityTaskReq.Namespace)
802+
completeActivityTaskReq.Namespace = namespaceRequestName
803+
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
804+
s.Equal(namespaceRequestName, completeActivityTaskReq.Namespace)
805+
806+
failActivityTaskReq := &workflowservice.RespondActivityTaskFailedRequest{}
807+
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
808+
s.Equal(namespaceEntryName, failActivityTaskReq.Namespace)
809+
failActivityTaskReq.Namespace = namespaceRequestName
810+
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
811+
s.Equal(namespaceRequestName, failActivityTaskReq.Namespace)
812+
}

service/frontend/fx.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func GrpcServerOptionsProvider(
178178
interceptors := []grpc.UnaryServerInterceptor{
179179
// Service Error Interceptor should be the most outer interceptor on error handling
180180
rpc.ServiceErrorInterceptor,
181-
namespaceValidatorInterceptor.LengthValidationIntercept,
181+
namespaceValidatorInterceptor.NamespaceValidateIntercept,
182182
namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor
183183
grpc.UnaryServerInterceptor(traceInterceptor),
184184
metrics.NewServerMetricsContextInjectorInterceptor(),

0 commit comments

Comments
 (0)