diff options
author | Tulir Asokan <tulir@maunium.net> | 2020-05-06 20:06:35 +0300 |
---|---|---|
committer | Tulir Asokan <tulir@maunium.net> | 2020-05-06 20:06:35 +0300 |
commit | 5b3e91524e000fe07fe30eff70061c9dd796014e (patch) | |
tree | 47eac32a4ce211b7d240841fe35e1a0cb745680c /matrix | |
parent | 96bb87e8ac8f45d56d487ea6c16d67f057d97e1f (diff) | |
parent | ebdfe914283fb91204ca8512a0a73a78fe41998f (diff) |
Merge branch 'e2ee'
Diffstat (limited to 'matrix')
-rw-r--r-- | matrix/crypto.go | 62 | ||||
-rw-r--r-- | matrix/matrix.go | 148 | ||||
-rw-r--r-- | matrix/nocrypto.go | 13 | ||||
-rw-r--r-- | matrix/rooms/room.go | 13 | ||||
-rw-r--r-- | matrix/rooms/roomcache.go | 24 | ||||
-rw-r--r-- | matrix/sync.go | 46 |
6 files changed, 265 insertions, 41 deletions
diff --git a/matrix/crypto.go b/matrix/crypto.go new file mode 100644 index 0000000..8eab355 --- /dev/null +++ b/matrix/crypto.go @@ -0,0 +1,62 @@ +// gomuks - A terminal Matrix client written in Go. +// Copyright (C) 2020 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +// +build cgo + +package matrix + +import ( + "path/filepath" + + "maunium.net/go/mautrix/crypto" + + "maunium.net/go/gomuks/debug" +) + +type cryptoLogger struct{} + +func (c cryptoLogger) Error(message string, args ...interface{}) { + debug.Printf("[Crypto/Error] "+message, args...) +} + +func (c cryptoLogger) Warn(message string, args ...interface{}) { + debug.Printf("[Crypto/Warn] "+message, args...) +} + +func (c cryptoLogger) Debug(message string, args ...interface{}) { + debug.Printf("[Crypto/Debug] "+message, args...) +} + +func (c cryptoLogger) Trace(message string, args ...interface{}) { + debug.Printf("[Crypto/Trace] "+message, args...) +} + +func isBadEncryptError(err error) bool { + return err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession +} + +func (c *Container) initCrypto() error { + cryptoStore, err := crypto.NewGobStore(filepath.Join(c.config.DataDir, "crypto.gob")) + if err != nil { + return err + } + c.crypto = crypto.NewOlmMachine(c.client, cryptoLogger{}, cryptoStore, c.config.Rooms) + err = c.crypto.Load() + if err != nil { + return err + } + return nil +} diff --git a/matrix/matrix.go b/matrix/matrix.go index d0fd2f4..30e28fb 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -17,7 +17,6 @@ package matrix import ( - "bytes" "context" "crypto/tls" "encoding/gob" @@ -37,6 +36,7 @@ import ( "github.com/pkg/errors" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" @@ -55,6 +55,7 @@ import ( // It is used for all Matrix calls from the UI and Matrix event handlers. type Container struct { client *mautrix.Client + crypto CryptoInterface syncer *GomuksSyncer gmx ifc.Gomuks ui ifc.GomuksUI @@ -88,6 +89,16 @@ func (log mxLogger) Debugfln(message string, args ...interface{}) { debug.Printf("[Matrix] "+message, args...) } +type CryptoInterface interface { + Load() error + FlushStore() error + ProcessSyncResponse(resp *mautrix.RespSync, since string) + HandleMemberEvent(*event.Event) + DecryptMegolmEvent(*event.Event) (*event.Event, error) + EncryptMegolmEvent(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error) + ShareGroupSession(id.RoomID, []id.UserID) error +} + // InitClient initializes the mautrix client and connects to the homeserver specified in the config. func (c *Container) InitClient() error { if len(c.config.HS) == 0 { @@ -97,6 +108,7 @@ func (c *Container) InitClient() error { if c.client != nil { c.Stop() c.client = nil + c.crypto = nil } var mxid id.UserID @@ -112,6 +124,12 @@ func (c *Container) InitClient() error { return err } c.client.Logger = mxLogger{} + c.client.DeviceID = c.config.DeviceID + + err = c.initCrypto() + if err != nil { + return err + } if c.history == nil { c.history, err = NewHistoryManager(c.config.HistoryPath) @@ -159,7 +177,9 @@ func (c *Container) PasswordLogin(user, password string) error { func (c *Container) finishLogin(resp *mautrix.RespLogin) { c.client.SetCredentials(resp.UserID, resp.AccessToken) + c.client.DeviceID = resp.DeviceID c.config.UserID = resp.UserID + c.config.DeviceID = resp.DeviceID c.config.AccessToken = resp.AccessToken c.config.Save() @@ -247,9 +267,10 @@ func (c *Container) Login(user, password string) error { // Logout revokes the access token, stops the syncer and calls the OnLogout() method of the UI. func (c *Container) Logout() { c.client.Logout() - c.config.DeleteSession() c.Stop() + c.config.DeleteSession() c.client = nil + c.crypto = nil c.ui.OnLogout() } @@ -265,6 +286,13 @@ func (c *Container) Stop() { debug.Print("Error closing history manager:", err) } c.history = nil + if c.crypto != nil { + debug.Print("Flushing crypto store") + err = c.crypto.FlushStore() + if err != nil { + debug.Print("Error flushing crypto store:", err) + } + } } } @@ -315,8 +343,20 @@ func (c *Container) OnLogin() { debug.Print("Initializing syncer") c.syncer = NewGomuksSyncer(c.config.Rooms) + if c.crypto != nil { + c.syncer.OnSync(c.crypto.ProcessSyncResponse) + c.syncer.OnEventType(event.StateMember, func(source EventSource, evt *event.Event) { + // Don't spam the crypto module with member events of an initial sync + // TODO invalidate all group sessions when clearing cache? + if c.config.AuthCache.InitialSyncDone { + c.crypto.HandleMemberEvent(evt) + } + }) + c.syncer.OnEventType(event.EventEncrypted, c.HandleEncrypted) + } else { + c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage) + } c.syncer.OnEventType(event.EventMessage, c.HandleMessage) - c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage) c.syncer.OnEventType(event.EventSticker, c.HandleMessage) c.syncer.OnEventType(event.EventReaction, c.HandleMessage) c.syncer.OnEventType(event.EventRedaction, c.HandleRedaction) @@ -516,6 +556,17 @@ func (c *Container) HandleReaction(room *rooms.Room, reactsTo id.EventID, reactE } } +func (c *Container) HandleEncrypted(source EventSource, mxEvent *event.Event) { + evt, err := c.crypto.DecryptMegolmEvent(mxEvent) + if err != nil { + debug.Print("Failed to decrypt event:", err) + // TODO add decryption failed message instead of passing through directly + c.HandleMessage(source, mxEvent) + return + } + c.HandleMessage(source, evt) +} + // HandleMessage is the event handler for the m.room.message timeline event. func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) { room := c.GetOrCreateRoom(mxEvent.RoomID) @@ -526,13 +577,16 @@ func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) { return } - rel := mxEvent.Content.AsMessage().GetRelatesTo() - if editID := rel.GetReplaceID(); len(editID) > 0 { - c.HandleEdit(room, editID, muksevt.Wrap(mxEvent)) - return - } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 { - c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent)) - return + relatable, ok := mxEvent.Content.Parsed.(event.Relatable) + if ok { + rel := relatable.GetRelatesTo() + if editID := rel.GetReplaceID(); len(editID) > 0 { + c.HandleEdit(room, editID, muksevt.Wrap(mxEvent)) + return + } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 { + c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent)) + return + } } events, err := c.history.Append(room, []*event.Event{mxEvent}) @@ -824,8 +878,28 @@ func (c *Container) Redact(roomID id.RoomID, eventID id.EventID, reason string) func (c *Container) SendEvent(evt *muksevt.Event) (id.EventID, error) { defer debug.Recover() - c.client.UserTyping(evt.RoomID, false, 0) + _, _ = c.client.UserTyping(evt.RoomID, false, 0) c.typing = 0 + room := c.GetRoom(evt.RoomID) + if room != nil && room.Encrypted && c.crypto != nil && evt.Type != event.EventReaction { + encrypted, err := c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content) + if err != nil { + if isBadEncryptError(err) { + return "", err + } + debug.Print("Got", err, "while trying to encrypt message, sharing group session and trying again...") + err = c.crypto.ShareGroupSession(room.ID, room.GetMemberList()) + if err != nil { + return "", err + } + encrypted, err = c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content) + if err != nil { + return "", err + } + } + evt.Type = event.EventEncrypted + evt.Content = event.Content{Parsed: encrypted} + } resp, err := c.client.SendMessageEvent(evt.RoomID, evt.Type, &evt.Content, mautrix.ReqSendEvent{TransactionID: evt.Unsigned.TransactionID}) if err != nil { return "", err @@ -923,11 +997,21 @@ func (c *Container) GetHistory(room *rooms.Room, limit int) ([]*muksevt.Event, e return nil, err } debug.Printf("Loaded %d events for %s from server from %s to %s", len(resp.Chunk), room.ID, resp.Start, resp.End) - for _, evt := range resp.Chunk { + for i, evt := range resp.Chunk { err := evt.Content.ParseRaw(evt.Type) if err != nil { debug.Printf("Failed to unmarshal content of event %s (type %s) by %s in %s: %v\n%s", evt.ID, evt.Type.Repr(), evt.Sender, evt.RoomID, err, string(evt.Content.VeryRaw)) } + + if c.crypto != nil && evt.Type == event.EventEncrypted { + decrypted, err := c.crypto.DecryptMegolmEvent(evt) + if err != nil { + debug.Print("Failed to decrypt event:", err) + // TODO add decryption failed message instead of passing through directly + } else { + resp.Chunk[i] = decrypted + } + } } for _, evt := range resp.State { room.UpdateState(evt) @@ -991,7 +1075,7 @@ func cp(src, dst string) error { return out.Close() } -func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath string, err error) { +func (c *Container) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (fullPath string, err error) { cachePath := c.GetCachePath(uri) if target == "" { fullPath = cachePath @@ -1002,21 +1086,27 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s } if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) { - var file *os.File - file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600) + var body io.ReadCloser + body, err = c.client.Download(uri) if err != nil { return } - defer file.Close() - var body io.ReadCloser - body, err = c.client.Download(uri) + var data []byte + data, err = ioutil.ReadAll(body) + _ = body.Close() if err != nil { return } - defer body.Close() - _, err = io.Copy(file, body) + if file != nil { + data, err = file.Decrypt(data) + if err != nil { + return + } + } + + err = ioutil.WriteFile(cachePath, data, 0600) if err != nil { return } @@ -1036,7 +1126,7 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s // Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors. // // The file will be either read from the media cache (if found) or downloaded from the server. -func (c *Container) Download(uri id.ContentURI) (data []byte, err error) { +func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) (data []byte, err error) { cacheFile := c.GetCachePath(uri) var info os.FileInfo if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() { @@ -1046,7 +1136,7 @@ func (c *Container) Download(uri id.ContentURI) (data []byte, err error) { } } - data, err = c.download(uri, cacheFile) + data, err = c.download(uri, file, cacheFile) return } @@ -1054,21 +1144,25 @@ func (c *Container) GetDownloadURL(uri id.ContentURI) string { return c.client.GetDownloadURL(uri) } -func (c *Container) download(uri id.ContentURI, cacheFile string) (data []byte, err error) { +func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) { var body io.ReadCloser body, err = c.client.Download(uri) if err != nil { return } - defer body.Close() - var buf bytes.Buffer - _, err = io.Copy(&buf, body) + data, err = ioutil.ReadAll(body) + _ = body.Close() if err != nil { return } - data = buf.Bytes() + if file != nil { + data, err = file.Decrypt(data) + if err != nil { + return + } + } err = ioutil.WriteFile(cacheFile, data, 0600) return diff --git a/matrix/nocrypto.go b/matrix/nocrypto.go new file mode 100644 index 0000000..979afda --- /dev/null +++ b/matrix/nocrypto.go @@ -0,0 +1,13 @@ +// This contains no-op stubs of the methods in crypto.go for non-cgo builds with crypto disabled. + +// +build !cgo + +package matrix + +func isBadEncryptError(err error) bool { + return false +} + +func (c *Container) initCrypto() error { + return nil +} diff --git a/matrix/rooms/room.go b/matrix/rooms/room.go index 0238cfb..d5d1d8f 100644 --- a/matrix/rooms/room.go +++ b/matrix/rooms/room.go @@ -412,7 +412,7 @@ func (room *Room) UpdateState(evt *event.Event) { case *event.TopicEventContent: room.topicCache = content.Topic case *event.EncryptionEventContent: - if content.Algorithm == event.AlgorithmMegolmV1 { + if content.Algorithm == id.AlgorithmMegolmV1 { room.Encrypted = true } } @@ -650,6 +650,17 @@ func (room *Room) GetMembers() map[id.UserID]*Member { return room.memberCache } +func (room *Room) GetMemberList() []id.UserID { + members := room.GetMembers() + memberList := make([]id.UserID, len(members)) + index := 0 + for userID, _ := range members { + memberList[index] = userID + index++ + } + return memberList +} + // GetMember returns the member with the given MXID. // If the member doesn't exist, nil is returned. func (room *Room) GetMember(userID id.UserID) *Member { diff --git a/matrix/rooms/roomcache.go b/matrix/rooms/roomcache.go index ffdcad1..067cbb6 100644 --- a/matrix/rooms/roomcache.go +++ b/matrix/rooms/roomcache.go @@ -27,6 +27,7 @@ import ( sync "github.com/sasha-s/go-deadlock" "maunium.net/go/gomuks/debug" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -67,6 +68,29 @@ func (cache *RoomCache) EnableUnloading() { cache.noUnload = false } +func (cache *RoomCache) IsEncrypted(roomID id.RoomID) bool { + room := cache.Get(roomID) + return room != nil && room.Encrypted +} + +func (cache *RoomCache) FindSharedRooms(userID id.UserID) (shared []id.RoomID) { + // FIXME this disables unloading so TouchNode wouldn't try to double-lock + cache.DisableUnloading() + cache.Lock() + for _, room := range cache.Map { + if !room.Encrypted { + continue + } + member, ok := room.GetMembers()[userID] + if ok && member.Membership == event.MembershipJoin { + shared = append(shared, room.ID) + } + } + cache.Unlock() + cache.EnableUnloading() + return +} + func (cache *RoomCache) LoadList() error { cache.Lock() defer cache.Unlock() diff --git a/matrix/sync.go b/matrix/sync.go index 85de68c..2136088 100644 --- a/matrix/sync.go +++ b/matrix/sync.go @@ -83,9 +83,11 @@ func (es EventSource) String() string { } type EventHandler func(source EventSource, event *event.Event) +type SyncHandler func(resp *mautrix.RespSync, since string) type GomuksSyncer struct { rooms *rooms.RoomCache + globalListeners []SyncHandler listeners map[event.Type][]EventHandler // event type to listeners array FirstSyncDone bool InitDoneCallback func() @@ -96,10 +98,11 @@ type GomuksSyncer struct { // NewGomuksSyncer returns an instantiated GomuksSyncer func NewGomuksSyncer(rooms *rooms.RoomCache) *GomuksSyncer { return &GomuksSyncer{ - rooms: rooms, - listeners: make(map[event.Type][]EventHandler), - FirstSyncDone: false, - Progress: StubSyncingModal{}, + rooms: rooms, + globalListeners: []SyncHandler{}, + listeners: make(map[event.Type][]EventHandler), + FirstSyncDone: false, + Progress: StubSyncingModal{}, } } @@ -109,23 +112,26 @@ func (s *GomuksSyncer) ProcessResponse(res *mautrix.RespSync, since string) (err s.rooms.DisableUnloading() } debug.Print("Received sync response") + s.Progress.SetMessage("Processing sync response") steps := len(res.Rooms.Join) + len(res.Rooms.Invite) + len(res.Rooms.Leave) - s.Progress.SetSteps(steps + 2) - s.Progress.SetMessage("Processing global events") - s.processSyncEvents(nil, res.Presence.Events, EventSourcePresence) - s.Progress.Step() - s.processSyncEvents(nil, res.AccountData.Events, EventSourceAccountData) - s.Progress.Step() + s.Progress.SetSteps(steps + 2 + len(s.globalListeners)) wait := &sync.WaitGroup{} - - wait.Add(steps) callback := func() { wait.Done() s.Progress.Step() } + wait.Add(len(s.globalListeners)) + s.notifyGlobalListeners(res, since, callback) + wait.Wait() + + s.processSyncEvents(nil, res.Presence.Events, EventSourcePresence) + s.Progress.Step() + s.processSyncEvents(nil, res.AccountData.Events, EventSourceAccountData) + s.Progress.Step() + + wait.Add(steps) - s.Progress.SetMessage("Processing room events") for roomID, roomData := range res.Rooms.Join { go s.processJoinedRoom(roomID, roomData, callback) } @@ -152,6 +158,15 @@ func (s *GomuksSyncer) ProcessResponse(res *mautrix.RespSync, since string) (err return } +func (s *GomuksSyncer) notifyGlobalListeners(res *mautrix.RespSync, since string, callback func()) { + for _, listener := range s.globalListeners { + go func(listener SyncHandler) { + listener(res, since) + callback() + }(listener) + } +} + func (s *GomuksSyncer) processJoinedRoom(roomID id.RoomID, roomData mautrix.SyncJoinedRoom, callback func()) { defer debug.Recover() room := s.rooms.GetOrCreate(roomID) @@ -239,6 +254,10 @@ func (s *GomuksSyncer) OnEventType(eventType event.Type, callback EventHandler) s.listeners[eventType] = append(s.listeners[eventType], callback) } +func (s *GomuksSyncer) OnSync(callback SyncHandler) { + s.globalListeners = append(s.globalListeners, callback) +} + func (s *GomuksSyncer) notifyListeners(source EventSource, evt *event.Event) { listeners, exists := s.listeners[evt.Type] if !exists { @@ -269,6 +288,7 @@ func (s *GomuksSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { event.StateCanonicalAlias, event.StatePowerLevels, event.StateTombstone, + event.StateEncryption, }, }, Timeline: mautrix.FilterPart{ |