Skip to content

Commit e6113da

Browse files
authored
Fix SQL NPE after connection is closed (#3829)
* fix npe for closed sql connections * connection test * cleanup * unexport * comments
1 parent 1c638ac commit e6113da

File tree

4 files changed

+117
-3
lines changed

4 files changed

+117
-3
lines changed

common/persistence/sql/factory.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,7 @@ func (c *DbConn) Close() error {
187187
defer c.Unlock()
188188
c.refCnt--
189189
if c.refCnt == 0 {
190-
err := c.DB.Close()
191-
c.DB = nil
192-
return err
190+
return c.DB.Close()
193191
}
194192
return nil
195193
}

common/persistence/tests/mysql_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,11 @@ func TestMySQLVisibilitySuite(t *testing.T) {
574574
s := sqltests.NewVisibilitySuite(t, store)
575575
suite.Run(t, s)
576576
}
577+
578+
func TestMySQLClosedConnectionError(t *testing.T) {
579+
testData, tearDown := setUpMySQLTest(t)
580+
defer tearDown()
581+
582+
s := newConnectionSuite(t, testData.Factory)
583+
suite.Run(t, s)
584+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// The MIT License
2+
//
3+
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
4+
//
5+
// Copyright (c) 2020 Uber Technologies, Inc.
6+
//
7+
// Permission is hereby granted, free of charge, to any person obtaining a copy
8+
// of this software and associated documentation files (the "Software"), to deal
9+
// in the Software without restriction, including without limitation the rights
10+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
// copies of the Software, and to permit persons to whom the Software is
12+
// furnished to do so, subject to the following conditions:
13+
//
14+
// The above copyright notice and this permission notice shall be included in
15+
// all copies or substantial portions of the Software.
16+
//
17+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
// THE SOFTWARE.
24+
25+
package tests
26+
27+
import (
28+
"context"
29+
"math/rand"
30+
"testing"
31+
"time"
32+
33+
"github.com/stretchr/testify/require"
34+
"github.com/stretchr/testify/suite"
35+
p "go.temporal.io/server/common/persistence"
36+
"go.temporal.io/server/common/persistence/serialization"
37+
"go.temporal.io/server/common/persistence/sql"
38+
)
39+
40+
type (
41+
connectionSuite struct {
42+
suite.Suite
43+
*require.Assertions
44+
45+
factory *sql.Factory
46+
}
47+
)
48+
49+
func newConnectionSuite(
50+
t *testing.T,
51+
factory *sql.Factory,
52+
) *connectionSuite {
53+
return &connectionSuite{
54+
Assertions: require.New(t),
55+
factory: factory,
56+
}
57+
}
58+
59+
func (s *connectionSuite) SetupSuite() {
60+
61+
}
62+
63+
func (s *connectionSuite) TearDownSuite() {
64+
65+
}
66+
67+
func (s *connectionSuite) SetupTest() {
68+
s.Assertions = require.New(s.T())
69+
}
70+
71+
func (s *connectionSuite) TearDownTest() {
72+
73+
}
74+
75+
// Tests that SQL operations do not panic if the underlying connection has been closed and that the persistence layer
76+
// returns a useful error message.
77+
// Currently only run against MySQL and Postgresql (SQLite always maintains at least one open connection)
78+
func (s *connectionSuite) TestClosedConnectionError() {
79+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
80+
defer cancel()
81+
82+
shardID := (int32)(1)
83+
rangeID := rand.Int63()
84+
shardInfo := RandomShardInfo(shardID, rangeID)
85+
86+
store, err := s.factory.NewShardStore()
87+
s.NoError(err)
88+
89+
store.Close() // Connection will be closed by this call
90+
manager := p.NewShardManager(store, serialization.NewSerializer())
91+
92+
resp, err := manager.GetOrCreateShard(ctx, &p.GetOrCreateShardRequest{
93+
ShardID: shardID,
94+
InitialShardInfo: shardInfo,
95+
})
96+
97+
s.Nil(resp)
98+
s.Error(err)
99+
s.ErrorContains(err, "closed")
100+
}

common/persistence/tests/postgresql_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,11 @@ func TestPostgreSQLVisibilitySuite(t *testing.T) {
581581
s := sqltests.NewVisibilitySuite(t, store)
582582
suite.Run(t, s)
583583
}
584+
585+
func TestPostgreSQLClosedConnectionError(t *testing.T) {
586+
testData, tearDown := setUpPostgreSQLTest(t)
587+
defer tearDown()
588+
589+
s := newConnectionSuite(t, testData.Factory)
590+
suite.Run(t, s)
591+
}

0 commit comments

Comments
 (0)