diff --git a/platform/bill.go b/platform/bill.go index 32ba88b..ae0d783 100644 --- a/platform/bill.go +++ b/platform/bill.go @@ -19,27 +19,42 @@ import ( "errors" "fmt" "sync" + "time" "github.com/sacloud/iaas-api-go" "github.com/sacloud/iaas-api-go/types" ) +var ( + BillAPIUpdateHourJST = 4 + BillAPIUpdateMinuteJST = 30 +) + // BillClient calls SakuraCloud bill API type BillClient interface { Read(context.Context) (*iaas.Bill, error) } func getBillClient(caller iaas.APICaller) BillClient { - return &billClient{caller: caller} + return &billClient{ + caller: caller, + cache: newCache(30 * time.Minute), + } } type billClient struct { caller iaas.APICaller accountID types.ID once sync.Once + cache *cache } func (c *billClient) Read(ctx context.Context) (*iaas.Bill, error) { + ca := c.cache.get() + if ca != nil { + return ca.(*iaas.Bill), nil + } + var err error c.once.Do(func() { var auth *iaas.AuthStatus @@ -74,5 +89,36 @@ func (c *billClient) Read(ctx context.Context) (*iaas.Bill, error) { bill = b } } + + n, err := nextCacheExpiresAt() + if err != nil { + return nil, err + } + err = c.cache.set(bill, n) + if err != nil { + return nil, err + } + return bill, nil } + +// キャッシュの有効期限を算出する +// +// Billing APIは1日1回 AM4:30 (JST) にデータが更新される。 +// このため、現在時刻がAM4:30 (JST) よりも早ければ当日のAM4:30 (JST)、 +// 現在時刻がAM4:30 (JST) よりも遅ければ翌日のAM4:30 (JST) を有効期限として扱う。 +func nextCacheExpiresAt() (time.Time, error) { + jst, err := time.LoadLocation("Asia/Tokyo") + if err != nil { + return time.Time{}, err + } + + // 実行環境のタイムゾーンは不定のためJSTを基準にする + now := time.Now().In(jst) + expiresAt := time.Date(now.Year(), now.Month(), now.Day(), BillAPIUpdateHourJST, BillAPIUpdateMinuteJST, 0, 0, jst) + if now.Equal(expiresAt) || now.After(expiresAt) { + expiresAt = expiresAt.Add(24 * time.Hour) + } + + return expiresAt, nil +} diff --git a/platform/cache.go b/platform/cache.go new file mode 100644 index 0000000..00c0dae --- /dev/null +++ b/platform/cache.go @@ -0,0 +1,50 @@ +package platform + +import ( + "errors" + "sync" + "time" +) + +type cache struct { + cleanupInterval time.Duration + expiresAt time.Time + item any + mu sync.Mutex +} + +func newCache(cleanupInterval time.Duration) *cache { + c := &cache{ + cleanupInterval: cleanupInterval, + } + + return c +} + +func (c *cache) set(item any, expiresAt time.Time) error { + if item == nil { + return errors.New("item is not set") + } + if expiresAt.IsZero() { + return errors.New("expiresAt is not set") + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.item = item + c.expiresAt = expiresAt + + return nil +} + +func (c *cache) get() any { + c.mu.Lock() + defer c.mu.Unlock() + + if time.Now().After(c.expiresAt) { + return nil + } + + return c.item +} diff --git a/platform/cache_test.go b/platform/cache_test.go new file mode 100644 index 0000000..c550b4b --- /dev/null +++ b/platform/cache_test.go @@ -0,0 +1,69 @@ +package platform + +import ( + "testing" + "time" +) + +func TestNewCache(t *testing.T) { + cleanupInterval := 10 * time.Minute + cache := newCache(cleanupInterval) + if cache.cleanupInterval != cleanupInterval { + t.Errorf("expected cleanupInterval %v, got %v", cleanupInterval, cache.cleanupInterval) + } +} + +func TestCache_Set(t *testing.T) { + cleanupInterval := 10 * time.Minute + cache := newCache(cleanupInterval) + + item := "dummy_item" + expiresAt := time.Now().Add(1 * time.Hour) + err := cache.set(item, expiresAt) + if err != nil { + t.Error(err) + } + if cache.item != item { + t.Errorf("item %v, got %v", item, cache.item) + } + if cache.expiresAt != expiresAt { + t.Errorf("expiresAt %v, got %v", expiresAt, cache.expiresAt) + } +} + +func TestCache_Get_ItemNotExpired(t *testing.T) { + cleanupInterval := 10 * time.Minute + cache := newCache(cleanupInterval) + + item := "dummy_item" + expiresAt := time.Now().Add(1 * time.Hour) + err := cache.set(item, expiresAt) + if err != nil { + t.Error(err) + } + + cachedItem := cache.get() + if cachedItem != item { + t.Errorf("cachedItem %v, got %v", item, cachedItem) + } +} + +func TestCache_Get_ItemExpired(t *testing.T) { + cleanupInterval := 1 * time.Second + cache := newCache(cleanupInterval) + + item := "dummy_item" + expiresAt := time.Now().Add(1 * time.Second) + err := cache.set(item, expiresAt) + if err != nil { + t.Error(err) + } + + // キャッシュの期限切れと削除が行われるのを待つ + time.Sleep(2 * time.Second) + + cachedItem := cache.get() + if cachedItem != nil { + t.Errorf("cached item not cleared, got %v", cachedItem) + } +}