From 8405014ffec184c291b704af60446a587d84e9a1 Mon Sep 17 00:00:00 2001 From: vyzo <vyzo@hackzen.org> Date: Thu, 17 Jan 2019 14:44:34 +0200 Subject: [PATCH] extend validator interface to include message source --- floodsub_test.go | 6 +++--- pubsub.go | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/floodsub_test.go b/floodsub_test.go index 3bf446d6..d9fb7326 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -357,7 +357,7 @@ func TestRegisterUnregisterValidator(t *testing.T) { hosts := getNetHosts(t, ctx, 1) psubs := getPubsubs(ctx, hosts) - err := psubs[0].RegisterTopicValidator("foo", func(context.Context, *Message) bool { + err := psubs[0].RegisterTopicValidator("foo", func(context.Context, peer.ID, *Message) bool { return true }) if err != nil { @@ -385,7 +385,7 @@ func TestValidate(t *testing.T) { connect(t, hosts[0], hosts[1]) topic := "foobar" - err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool { + err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, from peer.ID, msg *Message) bool { return !bytes.Contains(msg.Data, []byte("illegal")) }) if err != nil { @@ -482,7 +482,7 @@ func TestValidateOverload(t *testing.T) { block := make(chan struct{}) err := psubs[1].RegisterTopicValidator(topic, - func(ctx context.Context, msg *Message) bool { + func(ctx context.Context, from peer.ID, msg *Message) bool { <-block return true }, diff --git a/pubsub.go b/pubsub.go index 3c5cebe2..b072826f 100644 --- a/pubsub.go +++ b/pubsub.go @@ -661,7 +661,7 @@ func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { } if len(vals) > 0 { - if !p.validateTopic(vals, msg) { + if !p.validateTopic(vals, src, msg) { log.Warningf("message validation failed; dropping message from %s", src) return } @@ -684,9 +684,9 @@ func (p *PubSub) validateSignature(msg *Message) bool { return true } -func (p *PubSub) validateTopic(vals []*topicVal, msg *Message) bool { +func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool { if len(vals) == 1 { - return p.validateSingleTopic(vals[0], msg) + return p.validateSingleTopic(vals[0], src, msg) } ctx, cancel := context.WithCancel(p.ctx) @@ -703,7 +703,7 @@ loop: select { case val.validateThrottle <- struct{}{}: go func(val *topicVal) { - rch <- val.validateMsg(ctx, msg) + rch <- val.validateMsg(ctx, src, msg) <-val.validateThrottle }(val) @@ -729,13 +729,13 @@ loop: } // fast path for single topic validation that avoids the extra goroutine -func (p *PubSub) validateSingleTopic(val *topicVal, msg *Message) bool { +func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool { select { case val.validateThrottle <- struct{}{}: ctx, cancel := context.WithCancel(p.ctx) defer cancel() - res := val.validateMsg(ctx, msg) + res := val.validateMsg(ctx, src, msg) <-val.validateThrottle return res @@ -900,7 +900,7 @@ type topicVal struct { } // Validator is a function that validates a message. -type Validator func(context.Context, *Message) bool +type Validator func(context.Context, peer.ID, *Message) bool // ValidatorOpt is an option for RegisterTopicValidator. type ValidatorOpt func(addVal *addValReq) error @@ -992,11 +992,11 @@ func (ps *PubSub) rmValidator(req *rmValReq) { } } -func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool { +func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool { vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) defer cancel() - valid := val.validate(vctx, msg) + valid := val.validate(vctx, src, msg) if !valid { log.Debugf("validation failed for topic %s", val.topic) }