Skip to content

Commit 3c3985a

Browse files
authored
Merge pull request #155 from ipfs/fix/pointer-reflection
typed encoder: improve pointer reflection
2 parents 8f644c8 + c99a709 commit 3c3985a

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

encoding.go

+23-4
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,35 @@ func MakeTypedEncoder(f interface{}) func(*Request) func(io.Writer) Encoder {
9191
panic("MakeTypedEncoder must receive a function matching func(*Request, io.Writer, ...)")
9292
}
9393

94-
valType := t.In(2)
95-
valTypePtr := reflect.PtrTo(valType)
94+
var (
95+
valType, valTypeAlt reflect.Type
96+
)
97+
98+
valType = t.In(2)
99+
valTypeIsPtr := valType.Kind() == reflect.Ptr
100+
if valTypeIsPtr {
101+
valTypeAlt = valType.Elem()
102+
} else {
103+
valTypeAlt = reflect.PtrTo(valType)
104+
}
96105

97106
return MakeEncoder(func(req *Request, w io.Writer, i interface{}) error {
98107
iType := reflect.TypeOf(i)
99108
iValue := reflect.ValueOf(i)
100109
switch iType {
101110
case valType:
102-
case valTypePtr:
103-
iValue = iValue.Elem()
111+
case valTypeAlt:
112+
if valTypeIsPtr {
113+
if iValue.CanAddr() {
114+
iValue = iValue.Addr()
115+
} else {
116+
oldValue := iValue
117+
iValue = reflect.New(iType)
118+
iValue.Elem().Set(oldValue)
119+
}
120+
} else {
121+
iValue = iValue.Elem()
122+
}
104123
default:
105124
return fmt.Errorf("unexpected type %T, expected %v", i, valType)
106125
}

encoding_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,31 @@ func TestMakeTypedEncoderByValue(t *testing.T) {
6161
}
6262
}
6363

64+
func TestMakeTypedEncoderByPointer(t *testing.T) {
65+
expErr := fmt.Errorf("command fooTestObj failed")
66+
f := MakeTypedEncoder(func(req *Request, w io.Writer, v *fooTestObj) error {
67+
if v.Good {
68+
return nil
69+
}
70+
return expErr
71+
})
72+
73+
req := &Request{}
74+
75+
encoderFunc := f(req)
76+
77+
buf := new(bytes.Buffer)
78+
encoder := encoderFunc(buf)
79+
80+
if err := encoder.Encode(fooTestObj{true}); err != nil {
81+
t.Fatal(err)
82+
}
83+
84+
if err := encoder.Encode(fooTestObj{false}); err != expErr {
85+
t.Fatal("expected: ", expErr)
86+
}
87+
}
88+
6489
func TestMakeTypedEncoderArrays(t *testing.T) {
6590
f := MakeTypedEncoder(func(req *Request, w io.Writer, v []fooTestObj) error {
6691
if len(v) != 2 {

0 commit comments

Comments
 (0)