diff options
Diffstat (limited to 'matrix/matrix.go')
-rw-r--r-- | matrix/matrix.go | 200 |
1 files changed, 140 insertions, 60 deletions
diff --git a/matrix/matrix.go b/matrix/matrix.go index d0fd2f4..e2e182e 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -17,7 +17,6 @@ package matrix import ( - "bytes" "context" "crypto/tls" "encoding/gob" @@ -37,6 +36,7 @@ import ( "github.com/pkg/errors" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" @@ -55,6 +55,7 @@ import ( // It is used for all Matrix calls from the UI and Matrix event handlers. type Container struct { client *mautrix.Client + crypto ifc.Crypto syncer *GomuksSyncer gmx ifc.Gomuks ui ifc.GomuksUI @@ -88,6 +89,10 @@ func (log mxLogger) Debugfln(message string, args ...interface{}) { debug.Printf("[Matrix] "+message, args...) } +func (c *Container) Crypto() ifc.Crypto { + return c.crypto +} + // InitClient initializes the mautrix client and connects to the homeserver specified in the config. func (c *Container) InitClient() error { if len(c.config.HS) == 0 { @@ -97,6 +102,7 @@ func (c *Container) InitClient() error { if c.client != nil { c.Stop() c.client = nil + c.crypto = nil } var mxid id.UserID @@ -112,6 +118,12 @@ func (c *Container) InitClient() error { return err } c.client.Logger = mxLogger{} + c.client.DeviceID = c.config.DeviceID + + err = c.initCrypto() + if err != nil { + return err + } if c.history == nil { c.history, err = NewHistoryManager(c.config.HistoryPath) @@ -159,7 +171,9 @@ func (c *Container) PasswordLogin(user, password string) error { func (c *Container) finishLogin(resp *mautrix.RespLogin) { c.client.SetCredentials(resp.UserID, resp.AccessToken) + c.client.DeviceID = resp.DeviceID c.config.UserID = resp.UserID + c.config.DeviceID = resp.DeviceID c.config.AccessToken = resp.AccessToken c.config.Save() @@ -247,9 +261,10 @@ func (c *Container) Login(user, password string) error { // Logout revokes the access token, stops the syncer and calls the OnLogout() method of the UI. func (c *Container) Logout() { c.client.Logout() - c.config.DeleteSession() c.Stop() + c.config.DeleteSession() c.client = nil + c.crypto = nil c.ui.OnLogout() } @@ -257,7 +272,10 @@ func (c *Container) Logout() { func (c *Container) Stop() { if c.running { debug.Print("Stopping Matrix container...") - c.stop <- true + select { + case c.stop <- true: + default: + } c.client.StopSync() debug.Print("Closing history manager...") err := c.history.Close() @@ -265,6 +283,13 @@ func (c *Container) Stop() { debug.Print("Error closing history manager:", err) } c.history = nil + if c.crypto != nil { + debug.Print("Flushing crypto store") + err = c.crypto.FlushStore() + if err != nil { + debug.Print("Error flushing crypto store:", err) + } + } } } @@ -315,8 +340,20 @@ func (c *Container) OnLogin() { debug.Print("Initializing syncer") c.syncer = NewGomuksSyncer(c.config.Rooms) + if c.crypto != nil { + c.syncer.OnSync(c.crypto.ProcessSyncResponse) + c.syncer.OnEventType(event.StateMember, func(source EventSource, evt *event.Event) { + // Don't spam the crypto module with member events of an initial sync + // TODO invalidate all group sessions when clearing cache? + if c.config.AuthCache.InitialSyncDone { + c.crypto.HandleMemberEvent(evt) + } + }) + c.syncer.OnEventType(event.EventEncrypted, c.HandleEncrypted) + } else { + c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage) + } c.syncer.OnEventType(event.EventMessage, c.HandleMessage) - c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage) c.syncer.OnEventType(event.EventSticker, c.HandleMessage) c.syncer.OnEventType(event.EventReaction, c.HandleMessage) c.syncer.OnEventType(event.EventRedaction, c.HandleRedaction) @@ -516,6 +553,17 @@ func (c *Container) HandleReaction(room *rooms.Room, reactsTo id.EventID, reactE } } +func (c *Container) HandleEncrypted(source EventSource, mxEvent *event.Event) { + evt, err := c.crypto.DecryptMegolmEvent(mxEvent) + if err != nil { + debug.Print("Failed to decrypt event:", err) + // TODO add decryption failed message instead of passing through directly + c.HandleMessage(source, mxEvent) + return + } + c.HandleMessage(source, evt) +} + // HandleMessage is the event handler for the m.room.message timeline event. func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) { room := c.GetOrCreateRoom(mxEvent.RoomID) @@ -526,13 +574,16 @@ func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) { return } - rel := mxEvent.Content.AsMessage().GetRelatesTo() - if editID := rel.GetReplaceID(); len(editID) > 0 { - c.HandleEdit(room, editID, muksevt.Wrap(mxEvent)) - return - } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 { - c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent)) - return + relatable, ok := mxEvent.Content.Parsed.(event.Relatable) + if ok { + rel := relatable.GetRelatesTo() + if editID := rel.GetReplaceID(); len(editID) > 0 { + c.HandleEdit(room, editID, muksevt.Wrap(mxEvent)) + return + } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 { + c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent)) + return + } } events, err := c.history.Append(room, []*event.Event{mxEvent}) @@ -635,26 +686,16 @@ func (c *Container) processOwnMembershipChange(evt *event.Event) { func (c *Container) parseReadReceipt(evt *event.Event) (largestTimestampEvent id.EventID) { var largestTimestamp int64 - for eventID, rawContent := range evt.Content.Raw { - content, ok := rawContent.(map[string]interface{}) - if !ok { - continue - } - - mRead, ok := content["m.read"].(map[string]interface{}) - if !ok { - continue - } - myInfo, ok := mRead[string(c.config.UserID)].(map[string]interface{}) + for eventID, receipts := range *evt.Content.AsReceipt() { + myInfo, ok := receipts.Read[c.config.UserID] if !ok { continue } - ts, ok := myInfo["ts"].(float64) - if int64(ts) > largestTimestamp { - largestTimestamp = int64(ts) - largestTimestampEvent = id.EventID(eventID) + if myInfo.Timestamp > largestTimestamp { + largestTimestamp = myInfo.Timestamp + largestTimestampEvent = eventID } } return @@ -681,19 +722,10 @@ func (c *Container) HandleReadReceipt(source EventSource, evt *event.Event) { func (c *Container) parseDirectChatInfo(evt *event.Event) map[*rooms.Room]bool { directChats := make(map[*rooms.Room]bool) - for _, rawRoomIDList := range evt.Content.Raw { - roomIDList, ok := rawRoomIDList.([]interface{}) - if !ok { - continue - } - - for _, rawRoomID := range roomIDList { - roomID, ok := rawRoomID.(string) - if !ok { - continue - } - - room := c.GetOrCreateRoom(id.RoomID(roomID)) + for _, roomIDList := range *evt.Content.AsDirectChats() { + for _, roomID := range roomIDList { + // TODO we shouldn't create direct chat rooms that we aren't in + room := c.GetOrCreateRoom(roomID) if room != nil && !room.HasLeft { directChats[room] = true } @@ -763,8 +795,13 @@ func (c *Container) HandleTyping(_ EventSource, evt *event.Event) { } func (c *Container) MarkRead(roomID id.RoomID, eventID id.EventID) { - urlPath := c.client.BuildURL("rooms", roomID, "receipt", "m.read", eventID) - _, _ = c.client.MakeRequest("POST", urlPath, struct{}{}, nil) + go func() { + defer debug.Recover() + err := c.client.MarkRead(roomID, eventID) + if err != nil { + debug.Print("Failed to mark %s in %s as read: %v", eventID, roomID, err) + } + }() } func (c *Container) PrepareMarkdownMessage(roomID id.RoomID, msgtype event.MessageType, text, html string, rel *ifc.Relation) *muksevt.Event { @@ -824,8 +861,28 @@ func (c *Container) Redact(roomID id.RoomID, eventID id.EventID, reason string) func (c *Container) SendEvent(evt *muksevt.Event) (id.EventID, error) { defer debug.Recover() - c.client.UserTyping(evt.RoomID, false, 0) + _, _ = c.client.UserTyping(evt.RoomID, false, 0) c.typing = 0 + room := c.GetRoom(evt.RoomID) + if room != nil && room.Encrypted && c.crypto != nil && evt.Type != event.EventReaction { + encrypted, err := c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content) + if err != nil { + if isBadEncryptError(err) { + return "", err + } + debug.Print("Got", err, "while trying to encrypt message, sharing group session and trying again...") + err = c.crypto.ShareGroupSession(room.ID, room.GetMemberList()) + if err != nil { + return "", err + } + encrypted, err = c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content) + if err != nil { + return "", err + } + } + evt.Type = event.EventEncrypted + evt.Content = event.Content{Parsed: encrypted} + } resp, err := c.client.SendMessageEvent(evt.RoomID, evt.Type, &evt.Content, mautrix.ReqSendEvent{TransactionID: evt.Unsigned.TransactionID}) if err != nil { return "", err @@ -923,11 +980,21 @@ func (c *Container) GetHistory(room *rooms.Room, limit int) ([]*muksevt.Event, e return nil, err } debug.Printf("Loaded %d events for %s from server from %s to %s", len(resp.Chunk), room.ID, resp.Start, resp.End) - for _, evt := range resp.Chunk { + for i, evt := range resp.Chunk { err := evt.Content.ParseRaw(evt.Type) if err != nil { debug.Printf("Failed to unmarshal content of event %s (type %s) by %s in %s: %v\n%s", evt.ID, evt.Type.Repr(), evt.Sender, evt.RoomID, err, string(evt.Content.VeryRaw)) } + + if c.crypto != nil && evt.Type == event.EventEncrypted { + decrypted, err := c.crypto.DecryptMegolmEvent(evt) + if err != nil { + debug.Print("Failed to decrypt event:", err) + // TODO add decryption failed message instead of passing through directly + } else { + resp.Chunk[i] = decrypted + } + } } for _, evt := range resp.State { room.UpdateState(evt) @@ -956,9 +1023,12 @@ func (c *Container) GetEvent(room *rooms.Room, eventID id.EventID) (*muksevt.Eve if err != nil { return nil, err } - evt = muksevt.Wrap(mxEvent) + err = mxEvent.Content.ParseRaw(mxEvent.Type) + if err != nil { + return nil, err + } debug.Printf("Loaded event %s from server", eventID) - return evt, nil + return muksevt.Wrap(mxEvent), nil } // GetOrCreateRoom gets the room instance stored in the session. @@ -991,7 +1061,7 @@ func cp(src, dst string) error { return out.Close() } -func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath string, err error) { +func (c *Container) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (fullPath string, err error) { cachePath := c.GetCachePath(uri) if target == "" { fullPath = cachePath @@ -1002,21 +1072,27 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s } if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) { - var file *os.File - file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600) + var body io.ReadCloser + body, err = c.client.Download(uri) if err != nil { return } - defer file.Close() - var body io.ReadCloser - body, err = c.client.Download(uri) + var data []byte + data, err = ioutil.ReadAll(body) + _ = body.Close() if err != nil { return } - defer body.Close() - _, err = io.Copy(file, body) + if file != nil { + data, err = file.Decrypt(data) + if err != nil { + return + } + } + + err = ioutil.WriteFile(cachePath, data, 0600) if err != nil { return } @@ -1036,7 +1112,7 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s // Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors. // // The file will be either read from the media cache (if found) or downloaded from the server. -func (c *Container) Download(uri id.ContentURI) (data []byte, err error) { +func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) (data []byte, err error) { cacheFile := c.GetCachePath(uri) var info os.FileInfo if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() { @@ -1046,7 +1122,7 @@ func (c *Container) Download(uri id.ContentURI) (data []byte, err error) { } } - data, err = c.download(uri, cacheFile) + data, err = c.download(uri, file, cacheFile) return } @@ -1054,21 +1130,25 @@ func (c *Container) GetDownloadURL(uri id.ContentURI) string { return c.client.GetDownloadURL(uri) } -func (c *Container) download(uri id.ContentURI, cacheFile string) (data []byte, err error) { +func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) { var body io.ReadCloser body, err = c.client.Download(uri) if err != nil { return } - defer body.Close() - var buf bytes.Buffer - _, err = io.Copy(&buf, body) + data, err = ioutil.ReadAll(body) + _ = body.Close() if err != nil { return } - data = buf.Bytes() + if file != nil { + data, err = file.Decrypt(data) + if err != nil { + return + } + } err = ioutil.WriteFile(cacheFile, data, 0600) return |