Skip to content

Commit a53bb83

Browse files
committed
Support periodic reload of sampling strategies file
Signed-off-by: defool <defool@foxmail.com>
1 parent b99114e commit a53bb83

File tree

3 files changed

+137
-17
lines changed

3 files changed

+137
-17
lines changed

plugin/sampling/strategystore/static/options.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,33 @@ package static
1616

1717
import (
1818
"flag"
19+
"time"
1920

2021
"github.com/spf13/viper"
2122
)
2223

2324
const (
24-
samplingStrategiesFile = "sampling.strategies-file"
25+
samplingStrategiesFile = "sampling.strategies-file"
26+
samplingStrategiesReloadInterval = "sampling.strategies-reload-interval"
2527
)
2628

2729
// Options holds configuration for the static sampling strategy store.
2830
type Options struct {
2931
// StrategiesFile is the path for the sampling strategies file in JSON format
3032
StrategiesFile string
33+
// ReloadInterval is the time interval to check and reload sampling strategies file
34+
ReloadInterval time.Duration
3135
}
3236

3337
// AddFlags adds flags for Options
3438
func AddFlags(flagSet *flag.FlagSet) {
3539
flagSet.String(samplingStrategiesFile, "", "The path for the sampling strategies file in JSON format. See sampling documentation to see format of the file")
40+
flagSet.Duration(samplingStrategiesReloadInterval, 0, "Reload interval to check and reload sampling strategies file. Zero value means no checks (default 0s)")
3641
}
3742

3843
// InitFromViper initializes Options with properties from viper
3944
func (opts *Options) InitFromViper(v *viper.Viper) *Options {
4045
opts.StrategiesFile = v.GetString(samplingStrategiesFile)
46+
opts.ReloadInterval = v.GetDuration(samplingStrategiesReloadInterval)
4147
return opts
4248
}

plugin/sampling/strategystore/static/strategy_store.go

+77-14
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ package static
1616

1717
import (
1818
"bytes"
19+
"context"
1920
"encoding/gob"
2021
"encoding/json"
2122
"fmt"
2223
"io/ioutil"
24+
"path/filepath"
25+
"sync/atomic"
26+
"time"
2327

2428
"go.uber.org/zap"
2529

@@ -30,31 +34,86 @@ import (
3034
type strategyStore struct {
3135
logger *zap.Logger
3236

33-
defaultStrategy *sampling.SamplingStrategyResponse
34-
serviceStrategies map[string]*sampling.SamplingStrategyResponse
37+
defaultStrategy atomic.Value
38+
serviceStrategies atomic.Value
39+
40+
ctx context.Context
41+
cancelFunc context.CancelFunc
3542
}
3643

3744
// NewStrategyStore creates a strategy store that holds static sampling strategies.
3845
func NewStrategyStore(options Options, logger *zap.Logger) (ss.StrategyStore, error) {
46+
ctx, cancelFunc := context.WithCancel(context.Background())
3947
h := &strategyStore{
40-
logger: logger,
41-
serviceStrategies: make(map[string]*sampling.SamplingStrategyResponse),
48+
logger: logger,
49+
ctx: ctx,
50+
cancelFunc: cancelFunc,
4251
}
52+
h.serviceStrategies.Store(make(map[string]*sampling.SamplingStrategyResponse))
4353
strategies, err := loadStrategies(options.StrategiesFile)
4454
if err != nil {
4555
return nil, err
4656
}
4757
h.parseStrategies(strategies)
58+
59+
if options.ReloadInterval > 0 {
60+
go h.autoUpdateStrategy(options.ReloadInterval, options.StrategiesFile)
61+
}
4862
return h, nil
4963
}
5064

5165
// GetSamplingStrategy implements StrategyStore#GetSamplingStrategy.
5266
func (h *strategyStore) GetSamplingStrategy(serviceName string) (*sampling.SamplingStrategyResponse, error) {
53-
if strategy, ok := h.serviceStrategies[serviceName]; ok {
67+
serviceStrategies, ok := h.serviceStrategies.Load().(map[string]*sampling.SamplingStrategyResponse)
68+
if !ok {
69+
return nil, fmt.Errorf("wrong type of serviceStrategies")
70+
}
71+
if strategy, ok := serviceStrategies[serviceName]; ok {
5472
return strategy, nil
5573
}
5674
h.logger.Debug("sampling strategy not found, using default", zap.String("service", serviceName))
57-
return h.defaultStrategy, nil
75+
return h.defaultStrategy.Load().(*sampling.SamplingStrategyResponse), nil
76+
}
77+
78+
// StopUpdateStrategy stops updating the strategy
79+
func (h *strategyStore) StopUpdateStrategy() {
80+
h.cancelFunc()
81+
}
82+
83+
func (h *strategyStore) autoUpdateStrategy(interval time.Duration, filePath string) {
84+
lastString := ""
85+
ticker := time.NewTicker(interval)
86+
defer ticker.Stop()
87+
for {
88+
select {
89+
case <-ticker.C:
90+
if currBytes, err := ioutil.ReadFile(filepath.Clean(filePath)); err == nil {
91+
currStr := string(currBytes)
92+
if lastString == currStr {
93+
continue
94+
}
95+
err := h.updateSamplingStrategy(currBytes)
96+
if err != nil {
97+
h.logger.Error("UpdateSamplingStrategy failed", zap.Error(err))
98+
}
99+
lastString = currStr
100+
} else {
101+
h.logger.Error("UpdateSamplingStrategy failed", zap.Error(err))
102+
}
103+
case <-h.ctx.Done():
104+
return
105+
}
106+
}
107+
}
108+
109+
func (h *strategyStore) updateSamplingStrategy(bytes []byte) error {
110+
var strategies strategies
111+
if err := json.Unmarshal(bytes, &strategies); err != nil {
112+
return fmt.Errorf("failed to unmarshal strategies: %w", err)
113+
}
114+
h.parseStrategies(&strategies)
115+
h.logger.Info("Updated strategy:" + string(bytes))
116+
return nil
58117
}
59118

60119
// TODO good candidate for a global util function
@@ -74,40 +133,44 @@ func loadStrategies(strategiesFile string) (*strategies, error) {
74133
}
75134

76135
func (h *strategyStore) parseStrategies(strategies *strategies) {
77-
h.defaultStrategy = defaultStrategyResponse()
136+
defaultStrategy := defaultStrategyResponse()
137+
h.defaultStrategy.Store(defaultStrategy)
78138
if strategies == nil {
79139
h.logger.Info("No sampling strategies provided, using defaults")
80140
return
81141
}
82142
if strategies.DefaultStrategy != nil {
83-
h.defaultStrategy = h.parseServiceStrategies(strategies.DefaultStrategy)
143+
defaultStrategy = h.parseServiceStrategies(strategies.DefaultStrategy)
84144
}
85145

86146
merge := true
87-
if h.defaultStrategy.OperationSampling == nil ||
88-
h.defaultStrategy.OperationSampling.PerOperationStrategies == nil {
147+
if defaultStrategy.OperationSampling == nil ||
148+
defaultStrategy.OperationSampling.PerOperationStrategies == nil {
89149
merge = false
90150
}
91151

152+
serviceStrategies := make(map[string]*sampling.SamplingStrategyResponse)
92153
for _, s := range strategies.ServiceStrategies {
93-
h.serviceStrategies[s.Service] = h.parseServiceStrategies(s)
154+
serviceStrategies[s.Service] = h.parseServiceStrategies(s)
94155

95156
// Merge with the default operation strategies, because only merging with
96157
// the default strategy has no effect on service strategies (the default strategy
97158
// is not merged with and only used as a fallback).
98-
opS := h.serviceStrategies[s.Service].OperationSampling
159+
opS := serviceStrategies[s.Service].OperationSampling
99160
if opS == nil {
100161
// Service has no per-operation strategies, so just reference the default settings.
101-
h.serviceStrategies[s.Service].OperationSampling = h.defaultStrategy.OperationSampling
162+
serviceStrategies[s.Service].OperationSampling = defaultStrategy.OperationSampling
102163
continue
103164
}
104165

105166
if merge {
106167
opS.PerOperationStrategies = mergePerOperationSamplingStrategies(
107168
opS.PerOperationStrategies,
108-
h.defaultStrategy.OperationSampling.PerOperationStrategies)
169+
defaultStrategy.OperationSampling.PerOperationStrategies)
109170
}
110171
}
172+
h.defaultStrategy.Store(defaultStrategy)
173+
h.serviceStrategies.Store(serviceStrategies)
111174
}
112175

113176
// mergePerOperationStrategies merges two operation strategies a and b, where a takes precedence over b.

plugin/sampling/strategystore/static/strategy_store_test.go

+53-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
package static
1616

1717
import (
18-
"fmt"
18+
"io/ioutil"
19+
"os"
20+
"strings"
1921
"testing"
22+
"time"
2023

2124
"github.com/stretchr/testify/assert"
2225
"github.com/stretchr/testify/require"
@@ -79,7 +82,7 @@ func TestPerOperationSamplingStrategies(t *testing.T) {
7982
os := s.OperationSampling
8083
assert.EqualValues(t, os.DefaultSamplingProbability, 0.8)
8184
require.Len(t, os.PerOperationStrategies, 4)
82-
fmt.Println(os)
85+
8386
assert.Equal(t, "op6", os.PerOperationStrategies[0].Operation)
8487
assert.EqualValues(t, 0.5, os.PerOperationStrategies[0].ProbabilisticSampling.SamplingRate)
8588
assert.Equal(t, "op1", os.PerOperationStrategies[1].Operation)
@@ -243,3 +246,51 @@ func TestDeepCopy(t *testing.T) {
243246
assert.False(t, copy == s)
244247
assert.EqualValues(t, copy, s)
245248
}
249+
250+
func TestAutoUpdateStrategy(t *testing.T) {
251+
// copy from fixtures/strategies.json
252+
tempFile, _ := ioutil.TempFile("", "for_go_test_*.json")
253+
tempFile.Close()
254+
255+
srcFile, dstFile := "fixtures/strategies.json", tempFile.Name()
256+
srcBytes, err := ioutil.ReadFile(srcFile)
257+
require.NoError(t, err)
258+
err = ioutil.WriteFile(dstFile, srcBytes, 0644)
259+
require.NoError(t, err)
260+
261+
interval := time.Millisecond * 10
262+
store, err := NewStrategyStore(Options{
263+
StrategiesFile: dstFile,
264+
ReloadInterval: interval,
265+
}, zap.NewNop())
266+
require.NoError(t, err)
267+
defer store.(*strategyStore).StopUpdateStrategy()
268+
269+
s, err := store.GetSamplingStrategy("foo")
270+
require.NoError(t, err)
271+
assert.EqualValues(t, makeResponse(sampling.SamplingStrategyType_PROBABILISTIC, 0.8), *s)
272+
273+
// update file
274+
newStr := strings.Replace(string(srcBytes), "0.8", "0.9", 1)
275+
err = ioutil.WriteFile(dstFile, []byte(newStr), 0644)
276+
require.NoError(t, err)
277+
278+
// wait for reload
279+
time.Sleep(interval * 2)
280+
281+
// verity reloading
282+
s, err = store.GetSamplingStrategy("foo")
283+
require.NoError(t, err)
284+
assert.EqualValues(t, makeResponse(sampling.SamplingStrategyType_PROBABILISTIC, 0.9), *s)
285+
286+
// check bad file content
287+
_ = ioutil.WriteFile(dstFile, []byte("bad value"), 0644)
288+
time.Sleep(interval * 2)
289+
290+
// check file not exist
291+
os.Remove(dstFile)
292+
293+
// wait for delete and update failed
294+
time.Sleep(interval * 2)
295+
296+
}

0 commit comments

Comments
 (0)