Skip to content

Commit 9ec7f43

Browse files
authored
Merge pull request #2103 from anywhy/fix_session_last_price
FIX: [session] fix last price concurrent read and write error
2 parents 69e394f + d31cb5a commit 9ec7f43

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

pkg/bbgo/session.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ type ExchangeSession struct {
220220

221221
lastPrices map[string]fixedpoint.Value
222222
lastPriceUpdatedAt time.Time
223+
lastPricesMutex sync.Mutex
223224

224225
// marketDataStores contains the market data store of each market
225226
marketDataStores map[string]*types.MarketDataStore
@@ -571,20 +572,20 @@ func (session *ExchangeSession) Init(ctx context.Context, environ *Environment)
571572
session.startPrices[kline.Symbol] = kline.Open
572573
}
573574

574-
session.lastPrices[kline.Symbol] = session.MarketDataStream.(*types.HeikinAshiStream).LastOrigin[kline.Symbol][kline.Interval].Close
575+
session.setLastPrice(kline.Symbol, session.MarketDataStream.(*types.HeikinAshiStream).LastOrigin[kline.Symbol][kline.Interval].Close)
575576
})
576577
} else {
577578
session.MarketDataStream.OnKLineClosed(func(kline types.KLine) {
578579
if _, ok := session.startPrices[kline.Symbol]; !ok {
579-
session.startPrices[kline.Symbol] = kline.Open
580+
session.setLastPrice(kline.Symbol, kline.Open)
580581
}
581582

582-
session.lastPrices[kline.Symbol] = kline.Close
583+
session.setLastPrice(kline.Symbol, kline.Close)
583584
})
584585
}
585586

586587
session.MarketDataStream.OnMarketTrade(func(trade types.Trade) {
587-
session.lastPrices[trade.Symbol] = trade.Price
588+
session.setLastPrice(trade.Symbol, trade.Price)
588589
})
589590

590591
// session-wide max borrowable updating worker
@@ -738,7 +739,7 @@ func (session *ExchangeSession) initSymbol(ctx context.Context, environ *Environ
738739
// update last prices by the given kline
739740
lastKLine := kLines[len(kLines)-1]
740741
if interval == minInterval {
741-
session.lastPrices[symbol] = lastKLine.Close
742+
session.setLastPrice(symbol, lastKLine.Close)
742743
}
743744

744745
for _, k := range kLines {
@@ -859,15 +860,21 @@ func (session *ExchangeSession) StartPrice(symbol string) (price fixedpoint.Valu
859860
}
860861

861862
func (session *ExchangeSession) LastPrice(symbol string) (price fixedpoint.Value, ok bool) {
863+
session.lastPricesMutex.Lock()
864+
defer session.lastPricesMutex.Unlock()
865+
862866
price, ok = session.lastPrices[symbol]
863867
return price, ok
864868
}
865869

866870
func (session *ExchangeSession) AllLastPrices() map[string]fixedpoint.Value {
867-
return session.lastPrices
871+
return session.LastPrices()
868872
}
869873

870874
func (session *ExchangeSession) LastPrices() map[string]fixedpoint.Value {
875+
session.lastPricesMutex.Lock()
876+
defer session.lastPricesMutex.Unlock()
877+
871878
return session.lastPrices
872879
}
873880

@@ -965,12 +972,12 @@ func (session *ExchangeSession) UpdatePrices(ctx context.Context, currencies []s
965972
// map things like BTCUSDT = {price}
966973
if market, ok := markets[k]; ok {
967974
if currency2.IsFiatCurrency(market.BaseCurrency) {
968-
session.lastPrices[k] = validPrice.Div(fixedpoint.One)
975+
session.setLastPrice(k, validPrice.Div(fixedpoint.One))
969976
} else {
970-
session.lastPrices[k] = validPrice
977+
session.setLastPrice(k, validPrice)
971978
}
972979
} else {
973-
session.lastPrices[k] = v.Last
980+
session.setLastPrice(k, v.Last)
974981
}
975982

976983
if v.Time.After(lastTime) {
@@ -1330,3 +1337,10 @@ func (session *ExchangeSession) UpdateMaxBorrowable(ctx context.Context) {
13301337
}
13311338
session.marginInfoUpdater.UpdateMaxBorrowable(ctx)
13321339
}
1340+
1341+
func (session *ExchangeSession) setLastPrice(symbol string, price fixedpoint.Value) {
1342+
session.lastPricesMutex.Lock()
1343+
defer session.lastPricesMutex.Unlock()
1344+
1345+
session.lastPrices[symbol] = price
1346+
}

pkg/bbgo/session_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,48 @@
11
package bbgo
2+
3+
import (
4+
"sync"
5+
"testing"
6+
7+
"github.com/c9s/bbgo/pkg/fixedpoint"
8+
)
9+
10+
func TestExchangeSession_LastPricesMutex_ConcurrentAccess(t *testing.T) {
11+
session := &ExchangeSession{
12+
lastPrices: make(map[string]fixedpoint.Value),
13+
}
14+
15+
var wg sync.WaitGroup
16+
symbol := "BTCUSDT"
17+
writeCount := 50
18+
19+
// Concurrently write to lastPrices
20+
for i := range writeCount {
21+
wg.Add(1)
22+
go func(i int) {
23+
defer wg.Done()
24+
price := fixedpoint.NewFromInt(int64(i))
25+
session.setLastPrice(symbol, price)
26+
}(i)
27+
}
28+
29+
// Concurrently read from lastPrices
30+
for range writeCount {
31+
wg.Add(1)
32+
go func() {
33+
defer wg.Done()
34+
_, _ = session.LastPrice(symbol)
35+
}()
36+
}
37+
38+
wg.Wait()
39+
40+
price, ok := session.LastPrice(symbol)
41+
if !ok {
42+
t.Fatalf("expected price for symbol %s", symbol)
43+
}
44+
45+
if price.Int64() < 0 || price.Int64() >= int64(writeCount) {
46+
t.Errorf("unexpected price %d, should be in [0, %d)", price.Int64(), writeCount)
47+
}
48+
}

0 commit comments

Comments
 (0)