From 8405014ffec184c291b704af60446a587d84e9a1 Mon Sep 17 00:00:00 2001
From: vyzo <vyzo@hackzen.org>
Date: Thu, 17 Jan 2019 14:44:34 +0200
Subject: [PATCH] extend validator interface to include message source

---
 floodsub_test.go |  6 +++---
 pubsub.go        | 18 +++++++++---------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/floodsub_test.go b/floodsub_test.go
index 3bf446d6..d9fb7326 100644
--- a/floodsub_test.go
+++ b/floodsub_test.go
@@ -357,7 +357,7 @@ func TestRegisterUnregisterValidator(t *testing.T) {
 	hosts := getNetHosts(t, ctx, 1)
 	psubs := getPubsubs(ctx, hosts)
 
-	err := psubs[0].RegisterTopicValidator("foo", func(context.Context, *Message) bool {
+	err := psubs[0].RegisterTopicValidator("foo", func(context.Context, peer.ID, *Message) bool {
 		return true
 	})
 	if err != nil {
@@ -385,7 +385,7 @@ func TestValidate(t *testing.T) {
 	connect(t, hosts[0], hosts[1])
 	topic := "foobar"
 
-	err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
+	err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, from peer.ID, msg *Message) bool {
 		return !bytes.Contains(msg.Data, []byte("illegal"))
 	})
 	if err != nil {
@@ -482,7 +482,7 @@ func TestValidateOverload(t *testing.T) {
 		block := make(chan struct{})
 
 		err := psubs[1].RegisterTopicValidator(topic,
-			func(ctx context.Context, msg *Message) bool {
+			func(ctx context.Context, from peer.ID, msg *Message) bool {
 				<-block
 				return true
 			},
diff --git a/pubsub.go b/pubsub.go
index 3c5cebe2..b072826f 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -661,7 +661,7 @@ func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
 	}
 
 	if len(vals) > 0 {
-		if !p.validateTopic(vals, msg) {
+		if !p.validateTopic(vals, src, msg) {
 			log.Warningf("message validation failed; dropping message from %s", src)
 			return
 		}
@@ -684,9 +684,9 @@ func (p *PubSub) validateSignature(msg *Message) bool {
 	return true
 }
 
-func (p *PubSub) validateTopic(vals []*topicVal, msg *Message) bool {
+func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
 	if len(vals) == 1 {
-		return p.validateSingleTopic(vals[0], msg)
+		return p.validateSingleTopic(vals[0], src, msg)
 	}
 
 	ctx, cancel := context.WithCancel(p.ctx)
@@ -703,7 +703,7 @@ loop:
 		select {
 		case val.validateThrottle <- struct{}{}:
 			go func(val *topicVal) {
-				rch <- val.validateMsg(ctx, msg)
+				rch <- val.validateMsg(ctx, src, msg)
 				<-val.validateThrottle
 			}(val)
 
@@ -729,13 +729,13 @@ loop:
 }
 
 // fast path for single topic validation that avoids the extra goroutine
-func (p *PubSub) validateSingleTopic(val *topicVal, msg *Message) bool {
+func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool {
 	select {
 	case val.validateThrottle <- struct{}{}:
 		ctx, cancel := context.WithCancel(p.ctx)
 		defer cancel()
 
-		res := val.validateMsg(ctx, msg)
+		res := val.validateMsg(ctx, src, msg)
 		<-val.validateThrottle
 
 		return res
@@ -900,7 +900,7 @@ type topicVal struct {
 }
 
 // Validator is a function that validates a message.
-type Validator func(context.Context, *Message) bool
+type Validator func(context.Context, peer.ID, *Message) bool
 
 // ValidatorOpt is an option for RegisterTopicValidator.
 type ValidatorOpt func(addVal *addValReq) error
@@ -992,11 +992,11 @@ func (ps *PubSub) rmValidator(req *rmValReq) {
 	}
 }
 
-func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
+func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool {
 	vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
 	defer cancel()
 
-	valid := val.validate(vctx, msg)
+	valid := val.validate(vctx, src, msg)
 	if !valid {
 		log.Debugf("validation failed for topic %s", val.topic)
 	}