diff options
author | Tulir Asokan <tulir@maunium.net> | 2019-04-05 23:44:17 +0300 |
---|---|---|
committer | Tulir Asokan <tulir@maunium.net> | 2019-04-05 23:44:17 +0300 |
commit | 7ad2103f8f2c9b7e3d12554634a68db973a05b36 (patch) | |
tree | e08c3eb377411bf6954ef24ff9205202f07e7296 /matrix | |
parent | 535fbbb4f7703845bb25484f6eb67b1389f2dd61 (diff) |
Move history storage to matrix package. Fixes #90
Diffstat (limited to 'matrix')
-rw-r--r-- | matrix/history.go | 247 | ||||
-rw-r--r-- | matrix/matrix.go | 54 | ||||
-rw-r--r-- | matrix/rooms/room.go | 9 |
3 files changed, 305 insertions, 5 deletions
diff --git a/matrix/history.go b/matrix/history.go new file mode 100644 index 0000000..1b99125 --- /dev/null +++ b/matrix/history.go @@ -0,0 +1,247 @@ +// gomuks - A terminal Matrix client written in Go. +// Copyright (C) 2019 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +package matrix + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "sync" + + bolt "go.etcd.io/bbolt" + + "maunium.net/go/gomuks/matrix/rooms" + "maunium.net/go/mautrix" +) + +type HistoryManager struct { + sync.Mutex + + db *bolt.DB + + historyEndPtr map[*rooms.Room]uint64 + historyLoadPtr map[*rooms.Room]uint64 +} + +var bucketRoomStreams = []byte("room_streams") +var bucketRoomEventIDs = []byte("room_event_ids") +var bucketStreamPointers = []byte("room_stream_pointers") + +const halfUint64 = ^uint64(0) >> 1 + +func NewHistoryManager(dbPath string) (*HistoryManager, error) { + hm := &HistoryManager{ + historyEndPtr: make(map[*rooms.Room]uint64), + historyLoadPtr: make(map[*rooms.Room]uint64), + } + db, err := bolt.Open(dbPath, 0600, nil) + if err != nil { + return nil, err + } + err = db.Update(func(tx *bolt.Tx) error { + _, err = tx.CreateBucketIfNotExists(bucketRoomStreams) + if err != nil { + return err + } + _, err = tx.CreateBucketIfNotExists(bucketRoomEventIDs) + if err != nil { + return err + } + _, err = tx.CreateBucketIfNotExists(bucketStreamPointers) + if err != nil { + return err + } + return nil + }) + if err != nil { + return nil, err + } + hm.db = db + return hm, nil +} + +func (hm *HistoryManager) Close() error { + return hm.db.Close() +} + +func (hm *HistoryManager) Get(room *rooms.Room, eventID string) (event *mautrix.Event, err error) { + err = hm.db.View(func(tx *bolt.Tx) error { + rid := []byte(room.ID) + eventIDs := tx.Bucket(bucketRoomEventIDs).Bucket(rid) + if eventIDs == nil { + return nil + } + streamIndex := eventIDs.Get([]byte(eventID)) + if streamIndex == nil { + return nil + } + stream := tx.Bucket(bucketRoomStreams).Bucket(rid) + eventData := stream.Get(streamIndex) + var umErr error + event, umErr = unmarshalEvent(eventData) + return umErr + }) + return +} + +func (hm *HistoryManager) Append(room *rooms.Room, events []*mautrix.Event) error { + return hm.store(room, events, true) +} + +func (hm *HistoryManager) Prepend(room *rooms.Room, events []*mautrix.Event) error { + return hm.store(room, events, false) +} + +func (hm *HistoryManager) store(room *rooms.Room, events []*mautrix.Event, append bool) error { + hm.Lock() + defer hm.Unlock() + err := hm.db.Update(func(tx *bolt.Tx) error { + streamPointers := tx.Bucket(bucketStreamPointers) + rid := []byte(room.ID) + stream, err := tx.Bucket(bucketRoomStreams).CreateBucketIfNotExists(rid) + if err != nil { + return err + } + eventIDs, err := tx.Bucket(bucketRoomEventIDs).CreateBucketIfNotExists(rid) + if err != nil { + return err + } + if stream.Sequence() < halfUint64 { + // The sequence counter (i.e. the future) the part after 2^63, i.e. the second half of uint64 + // We set it to -1 because NextSequence will increment it by one. + err = stream.SetSequence(halfUint64 - 1) + if err != nil { + return err + } + } + if append { + ptrStart, err := stream.NextSequence() + if err != nil { + return err + } + for i, event := range events { + if err := put(stream, eventIDs, event, ptrStart+uint64(i)); err != nil { + return err + } + } + err = stream.SetSequence(ptrStart + uint64(len(events)) - 1) + if err != nil { + return err + } + } else { + ptrStart, ok := hm.historyEndPtr[room] + if !ok { + ptrStartRaw := streamPointers.Get(rid) + if ptrStartRaw != nil { + ptrStart = btoi(ptrStartRaw) + } else { + ptrStart = halfUint64 - 1 + } + } + eventCount := uint64(len(events)) + for i, event := range events { + if err := put(stream, eventIDs, event, -ptrStart-uint64(i)); err != nil { + return err + } + } + hm.historyEndPtr[room] = ptrStart + eventCount + err := streamPointers.Put(rid, itob(ptrStart+eventCount)) + if err != nil { + return err + } + } + + return nil + }) + return err +} + +func (hm *HistoryManager) Load(room *rooms.Room, num int) (events []*mautrix.Event, err error) { + hm.Lock() + defer hm.Unlock() + err = hm.db.View(func(tx *bolt.Tx) error { + rid := []byte(room.ID) + stream := tx.Bucket(bucketRoomStreams).Bucket(rid) + if stream == nil { + return nil + } + ptrStart, ok := hm.historyLoadPtr[room] + if !ok { + ptrStart = stream.Sequence() + } + c := stream.Cursor() + k, v := c.Seek(itob(ptrStart - uint64(num))) + ptrStartFound := btoi(k) + if k == nil || ptrStartFound >= ptrStart { + return nil + } + hm.historyLoadPtr[room] = ptrStartFound - 1 + for ; k != nil && btoi(k) < ptrStart; k, v = c.Next() { + event, parseError := unmarshalEvent(v) + if parseError != nil { + return parseError + } + events = append(events, event) + } + return nil + }) + // Reverse array because we read/append the history in reverse order. + i := 0 + j := len(events) - 1 + for i < j { + events[i], events[j] = events[j], events[i] + i++ + j-- + } + return +} + +func itob(v uint64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, v) + return b +} + +func btoi(b []byte) uint64 { + return binary.BigEndian.Uint64(b) +} + +func marshalEvent(event *mautrix.Event) ([]byte, error) { + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(event) + return buf.Bytes(), err +} + +func unmarshalEvent(data []byte) (*mautrix.Event, error) { + event := &mautrix.Event{} + return event, gob.NewDecoder(bytes.NewReader(data)).Decode(event) +} + +func put(streams, eventIDs *bolt.Bucket, event *mautrix.Event, key uint64) error { + data, err := marshalEvent(event) + if err != nil { + return err + } + keyBytes := itob(key) + if err = streams.Put(keyBytes, data); err != nil { + return err + } + if err = eventIDs.Put([]byte(event.ID), keyBytes); err != nil { + return err + } + return nil +} diff --git a/matrix/matrix.go b/matrix/matrix.go index 9527451..6985659 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -50,6 +50,7 @@ type Container struct { gmx ifc.Gomuks ui ifc.GomuksUI config *config.Config + history *HistoryManager running bool stop chan bool @@ -102,6 +103,11 @@ func (c *Container) InitClient() error { } c.client.Logger = mxLogger{} + c.history, err = NewHistoryManager(c.config.HistoryPath) + if err != nil { + return err + } + allowInsecure := len(os.Getenv("GOMUKS_ALLOW_INSECURE_CONNECTIONS")) > 0 if allowInsecure { c.client.Client = &http.Client{ @@ -158,6 +164,11 @@ func (c *Container) Stop() { debug.Print("Stopping Matrix container...") c.stop <- true c.client.StopSync() + debug.Print("Closing history manager...") + err := c.history.Close() + if err != nil { + debug.Print("Error closing history manager:", err) + } } } @@ -281,6 +292,11 @@ func (c *Container) HandleMessage(source EventSource, evt *mautrix.Event) { return } + err := c.history.Append(roomView.MxRoom(), []*mautrix.Event{evt}) + if err != nil { + debug.Printf("Failed to add event %s to history: %v", evt.ID, err) + } + message := mainView.ParseEvent(roomView, evt) if message != nil { roomView.AddMessage(message, ifc.AppendMessage) @@ -537,12 +553,42 @@ func (c *Container) LeaveRoom(roomID string) error { } // GetHistory fetches room history. -func (c *Container) GetHistory(roomID, prevBatch string, limit int) ([]*mautrix.Event, string, error) { - resp, err := c.client.Messages(roomID, prevBatch, "", 'b', limit) +func (c *Container) GetHistory(room *rooms.Room, limit int) ([]*mautrix.Event, error) { + events, err := c.history.Load(room, limit) if err != nil { - return nil, "", err + return nil, err + } + if len(events) > 0 { + debug.Printf("Loaded %d events for %s from local cache", len(events), room.ID) + return events, nil + } + resp, err := c.client.Messages(room.ID, room.PrevBatch, "", 'b', limit) + if err != nil { + return nil, err + } + if len(resp.Chunk) > 0 { + err = c.history.Prepend(room, resp.Chunk) + if err != nil { + return nil, err + } + } + room.PrevBatch = resp.End + debug.Printf("Loaded %d events for %s from server from %s to %s", len(resp.Chunk), room.ID, resp.Start, resp.End) + return resp.Chunk, nil +} + +func (c *Container) GetEvent(room *rooms.Room, eventID string) (*mautrix.Event, error) { + event, err := c.history.Get(room, eventID) + if event != nil || err != nil { + debug.Printf("Found event %s in local cache", eventID) + return event, err + } + event, err = c.client.GetEvent(room.ID, eventID) + if err != nil { + return nil, err } - return resp.Chunk, resp.End, nil + debug.Printf("Loaded event %s from server", eventID) + return event, nil } // GetRoom gets the room instance stored in the session. diff --git a/matrix/rooms/room.go b/matrix/rooms/room.go index 16d4a9e..47f5602 100644 --- a/matrix/rooms/room.go +++ b/matrix/rooms/room.go @@ -57,6 +57,7 @@ type UnreadMessage struct { Highlight bool } + // Room represents a single Matrix room. type Room struct { *mautrix.Room @@ -74,6 +75,7 @@ type Room struct { UnreadMessages []UnreadMessage unreadCountCache *int highlightCache *bool + lastMarkedRead string // Whether or not this room is marked as a direct chat. IsDirect bool @@ -142,7 +144,11 @@ func (room *Room) Save(path string) error { } // MarkRead clears the new message statuses on this room. -func (room *Room) MarkRead(eventID string) { +func (room *Room) MarkRead(eventID string) bool { + if room.lastMarkedRead == eventID { + return false + } + room.lastMarkedRead = eventID readToIndex := -1 for index, unreadMessage := range room.UnreadMessages { if unreadMessage.EventID == eventID { @@ -154,6 +160,7 @@ func (room *Room) MarkRead(eventID string) { room.highlightCache = nil room.unreadCountCache = nil } + return true } func (room *Room) UnreadCount() int { |