diff options
Diffstat (limited to 'matrix/rooms')
-rw-r--r-- | matrix/rooms/room.go | 61 |
1 files changed, 33 insertions, 28 deletions
diff --git a/matrix/rooms/room.go b/matrix/rooms/room.go index 53156a3..6334036 100644 --- a/matrix/rooms/room.go +++ b/matrix/rooms/room.go @@ -57,7 +57,6 @@ type UnreadMessage struct { Highlight bool } - // Room represents a single Matrix room. type Room struct { *mautrix.Room @@ -101,26 +100,7 @@ type Room struct { // The list of aliases. Directly fetched from the m.room.aliases state event. aliasesCache []string - // fetchHistoryLock is used to make sure multiple goroutines don't fetch - // history for this room at the same time. - fetchHistoryLock *sync.Mutex -} - -// LockHistory locks the history fetching mutex. -// If the mutex is nil, it will be created. -func (room *Room) LockHistory() { - if room.fetchHistoryLock == nil { - room.fetchHistoryLock = &sync.Mutex{} - } - room.fetchHistoryLock.Lock() -} - -// UnlockHistory unlocks the history fetching mutex. -// If the mutex is nil, this does nothing. -func (room *Room) UnlockHistory() { - if room.fetchHistoryLock != nil { - room.fetchHistoryLock.Unlock() - } + lock sync.RWMutex } func (room *Room) Load(path string) error { @@ -130,6 +110,8 @@ func (room *Room) Load(path string) error { } defer file.Close() dec := gob.NewDecoder(file) + room.lock.Lock() + defer room.lock.Unlock() return dec.Decode(room) } @@ -140,11 +122,15 @@ func (room *Room) Save(path string) error { } defer file.Close() enc := gob.NewEncoder(file) + room.lock.RLock() + defer room.lock.RUnlock() return enc.Encode(room) } // MarkRead clears the new message statuses on this room. func (room *Room) MarkRead(eventID string) bool { + room.lock.Lock() + defer room.lock.Unlock() if room.lastMarkedRead == eventID { return false } @@ -164,6 +150,8 @@ func (room *Room) MarkRead(eventID string) bool { } func (room *Room) UnreadCount() int { + room.lock.Lock() + defer room.lock.Unlock() if room.unreadCountCache == nil { room.unreadCountCache = new(int) for _, unreadMessage := range room.UnreadMessages { @@ -176,6 +164,8 @@ func (room *Room) UnreadCount() int { } func (room *Room) Highlighted() bool { + room.lock.Lock() + defer room.lock.Unlock() if room.highlightCache == nil { room.highlightCache = new(bool) for _, unreadMessage := range room.UnreadMessages { @@ -193,6 +183,8 @@ func (room *Room) HasNewMessages() bool { } func (room *Room) AddUnread(eventID string, counted, highlight bool) { + room.lock.Lock() + defer room.lock.Unlock() room.UnreadMessages = append(room.UnreadMessages, UnreadMessage{ EventID: eventID, Counted: counted, @@ -213,6 +205,8 @@ func (room *Room) AddUnread(eventID string, counted, highlight bool) { } func (room *Room) Tags() []RoomTag { + room.lock.RLock() + defer room.lock.RUnlock() if len(room.RawTags) == 0 { if room.IsDirect { return []RoomTag{{"net.maunium.gomuks.fake.direct", "0.5"}} @@ -225,6 +219,8 @@ func (room *Room) Tags() []RoomTag { // UpdateState updates the room's current state with the given Event. This will clobber events based // on the type/state_key combination. func (room *Room) UpdateState(event *mautrix.Event) { + room.lock.Lock() + defer room.lock.Unlock() _, exists := room.State[event.Type] if !exists { room.State[event.Type] = make(map[string]*mautrix.Event) @@ -269,13 +265,15 @@ func (room *Room) UpdateState(event *mautrix.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room *Room) GetStateEvent(eventType mautrix.EventType, stateKey string) *mautrix.Event { + room.lock.RLock() + defer room.lock.RUnlock() stateEventMap, _ := room.State[eventType] event, _ := stateEventMap[stateKey] return event } -// GetStateEvents returns the state events for the given type. -func (room *Room) GetStateEvents(eventType mautrix.EventType) map[string]*mautrix.Event { +// getStateEvents returns the state events for the given type. +func (room *Room) getStateEvents(eventType mautrix.EventType) map[string]*mautrix.Event { stateEventMap, _ := room.State[eventType] return stateEventMap } @@ -309,11 +307,13 @@ func (room *Room) GetCanonicalAlias() string { // GetAliases returns the list of aliases that point to this room. func (room *Room) GetAliases() []string { if room.aliasesCache == nil { - aliasEvents := room.GetStateEvents(mautrix.StateAliases) + room.lock.RLock() + aliasEvents := room.getStateEvents(mautrix.StateAliases) room.aliasesCache = []string{} for _, event := range aliasEvents { room.aliasesCache = append(room.aliasesCache, event.Content.Aliases...) } + room.lock.RUnlock() } return room.aliasesCache } @@ -394,7 +394,8 @@ func (room *Room) GetTitle() string { // createMemberCache caches all member events into a easily processable MXID -> *Member map. func (room *Room) createMemberCache() map[string]*mautrix.Member { cache := make(map[string]*mautrix.Member) - events := room.GetStateEvents(mautrix.StateMember) + room.lock.RLock() + events := room.getStateEvents(mautrix.StateMember) room.firstMemberCache = nil if events != nil { for userID, event := range events { @@ -411,7 +412,10 @@ func (room *Room) createMemberCache() map[string]*mautrix.Member { } } } + room.lock.RUnlock() + room.lock.Lock() room.memberCache = cache + room.lock.Unlock() return cache } @@ -432,7 +436,9 @@ func (room *Room) GetMember(userID string) *mautrix.Member { if len(room.memberCache) == 0 { room.createMemberCache() } + room.lock.RLock() member, _ := room.memberCache[userID] + room.lock.RUnlock() return member } @@ -444,8 +450,7 @@ func (room *Room) GetSessionOwner() string { // NewRoom creates a new Room with the given ID func NewRoom(roomID, owner string) *Room { return &Room{ - Room: mautrix.NewRoom(roomID), - fetchHistoryLock: &sync.Mutex{}, - SessionUserID: owner, + Room: mautrix.NewRoom(roomID), + SessionUserID: owner, } } |