aboutsummaryrefslogtreecommitdiff
path: root/matrix/matrix.go
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/matrix.go')
-rw-r--r--matrix/matrix.go200
1 files changed, 140 insertions, 60 deletions
diff --git a/matrix/matrix.go b/matrix/matrix.go
index d0fd2f4..e2e182e 100644
--- a/matrix/matrix.go
+++ b/matrix/matrix.go
@@ -17,7 +17,6 @@
package matrix
import (
- "bytes"
"context"
"crypto/tls"
"encoding/gob"
@@ -37,6 +36,7 @@ import (
"github.com/pkg/errors"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@@ -55,6 +55,7 @@ import (
// It is used for all Matrix calls from the UI and Matrix event handlers.
type Container struct {
client *mautrix.Client
+ crypto ifc.Crypto
syncer *GomuksSyncer
gmx ifc.Gomuks
ui ifc.GomuksUI
@@ -88,6 +89,10 @@ func (log mxLogger) Debugfln(message string, args ...interface{}) {
debug.Printf("[Matrix] "+message, args...)
}
+func (c *Container) Crypto() ifc.Crypto {
+ return c.crypto
+}
+
// InitClient initializes the mautrix client and connects to the homeserver specified in the config.
func (c *Container) InitClient() error {
if len(c.config.HS) == 0 {
@@ -97,6 +102,7 @@ func (c *Container) InitClient() error {
if c.client != nil {
c.Stop()
c.client = nil
+ c.crypto = nil
}
var mxid id.UserID
@@ -112,6 +118,12 @@ func (c *Container) InitClient() error {
return err
}
c.client.Logger = mxLogger{}
+ c.client.DeviceID = c.config.DeviceID
+
+ err = c.initCrypto()
+ if err != nil {
+ return err
+ }
if c.history == nil {
c.history, err = NewHistoryManager(c.config.HistoryPath)
@@ -159,7 +171,9 @@ func (c *Container) PasswordLogin(user, password string) error {
func (c *Container) finishLogin(resp *mautrix.RespLogin) {
c.client.SetCredentials(resp.UserID, resp.AccessToken)
+ c.client.DeviceID = resp.DeviceID
c.config.UserID = resp.UserID
+ c.config.DeviceID = resp.DeviceID
c.config.AccessToken = resp.AccessToken
c.config.Save()
@@ -247,9 +261,10 @@ func (c *Container) Login(user, password string) error {
// Logout revokes the access token, stops the syncer and calls the OnLogout() method of the UI.
func (c *Container) Logout() {
c.client.Logout()
- c.config.DeleteSession()
c.Stop()
+ c.config.DeleteSession()
c.client = nil
+ c.crypto = nil
c.ui.OnLogout()
}
@@ -257,7 +272,10 @@ func (c *Container) Logout() {
func (c *Container) Stop() {
if c.running {
debug.Print("Stopping Matrix container...")
- c.stop <- true
+ select {
+ case c.stop <- true:
+ default:
+ }
c.client.StopSync()
debug.Print("Closing history manager...")
err := c.history.Close()
@@ -265,6 +283,13 @@ func (c *Container) Stop() {
debug.Print("Error closing history manager:", err)
}
c.history = nil
+ if c.crypto != nil {
+ debug.Print("Flushing crypto store")
+ err = c.crypto.FlushStore()
+ if err != nil {
+ debug.Print("Error flushing crypto store:", err)
+ }
+ }
}
}
@@ -315,8 +340,20 @@ func (c *Container) OnLogin() {
debug.Print("Initializing syncer")
c.syncer = NewGomuksSyncer(c.config.Rooms)
+ if c.crypto != nil {
+ c.syncer.OnSync(c.crypto.ProcessSyncResponse)
+ c.syncer.OnEventType(event.StateMember, func(source EventSource, evt *event.Event) {
+ // Don't spam the crypto module with member events of an initial sync
+ // TODO invalidate all group sessions when clearing cache?
+ if c.config.AuthCache.InitialSyncDone {
+ c.crypto.HandleMemberEvent(evt)
+ }
+ })
+ c.syncer.OnEventType(event.EventEncrypted, c.HandleEncrypted)
+ } else {
+ c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage)
+ }
c.syncer.OnEventType(event.EventMessage, c.HandleMessage)
- c.syncer.OnEventType(event.EventEncrypted, c.HandleMessage)
c.syncer.OnEventType(event.EventSticker, c.HandleMessage)
c.syncer.OnEventType(event.EventReaction, c.HandleMessage)
c.syncer.OnEventType(event.EventRedaction, c.HandleRedaction)
@@ -516,6 +553,17 @@ func (c *Container) HandleReaction(room *rooms.Room, reactsTo id.EventID, reactE
}
}
+func (c *Container) HandleEncrypted(source EventSource, mxEvent *event.Event) {
+ evt, err := c.crypto.DecryptMegolmEvent(mxEvent)
+ if err != nil {
+ debug.Print("Failed to decrypt event:", err)
+ // TODO add decryption failed message instead of passing through directly
+ c.HandleMessage(source, mxEvent)
+ return
+ }
+ c.HandleMessage(source, evt)
+}
+
// HandleMessage is the event handler for the m.room.message timeline event.
func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) {
room := c.GetOrCreateRoom(mxEvent.RoomID)
@@ -526,13 +574,16 @@ func (c *Container) HandleMessage(source EventSource, mxEvent *event.Event) {
return
}
- rel := mxEvent.Content.AsMessage().GetRelatesTo()
- if editID := rel.GetReplaceID(); len(editID) > 0 {
- c.HandleEdit(room, editID, muksevt.Wrap(mxEvent))
- return
- } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 {
- c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent))
- return
+ relatable, ok := mxEvent.Content.Parsed.(event.Relatable)
+ if ok {
+ rel := relatable.GetRelatesTo()
+ if editID := rel.GetReplaceID(); len(editID) > 0 {
+ c.HandleEdit(room, editID, muksevt.Wrap(mxEvent))
+ return
+ } else if reactionID := rel.GetAnnotationID(); mxEvent.Type == event.EventReaction && len(reactionID) > 0 {
+ c.HandleReaction(room, reactionID, muksevt.Wrap(mxEvent))
+ return
+ }
}
events, err := c.history.Append(room, []*event.Event{mxEvent})
@@ -635,26 +686,16 @@ func (c *Container) processOwnMembershipChange(evt *event.Event) {
func (c *Container) parseReadReceipt(evt *event.Event) (largestTimestampEvent id.EventID) {
var largestTimestamp int64
- for eventID, rawContent := range evt.Content.Raw {
- content, ok := rawContent.(map[string]interface{})
- if !ok {
- continue
- }
-
- mRead, ok := content["m.read"].(map[string]interface{})
- if !ok {
- continue
- }
- myInfo, ok := mRead[string(c.config.UserID)].(map[string]interface{})
+ for eventID, receipts := range *evt.Content.AsReceipt() {
+ myInfo, ok := receipts.Read[c.config.UserID]
if !ok {
continue
}
- ts, ok := myInfo["ts"].(float64)
- if int64(ts) > largestTimestamp {
- largestTimestamp = int64(ts)
- largestTimestampEvent = id.EventID(eventID)
+ if myInfo.Timestamp > largestTimestamp {
+ largestTimestamp = myInfo.Timestamp
+ largestTimestampEvent = eventID
}
}
return
@@ -681,19 +722,10 @@ func (c *Container) HandleReadReceipt(source EventSource, evt *event.Event) {
func (c *Container) parseDirectChatInfo(evt *event.Event) map[*rooms.Room]bool {
directChats := make(map[*rooms.Room]bool)
- for _, rawRoomIDList := range evt.Content.Raw {
- roomIDList, ok := rawRoomIDList.([]interface{})
- if !ok {
- continue
- }
-
- for _, rawRoomID := range roomIDList {
- roomID, ok := rawRoomID.(string)
- if !ok {
- continue
- }
-
- room := c.GetOrCreateRoom(id.RoomID(roomID))
+ for _, roomIDList := range *evt.Content.AsDirectChats() {
+ for _, roomID := range roomIDList {
+ // TODO we shouldn't create direct chat rooms that we aren't in
+ room := c.GetOrCreateRoom(roomID)
if room != nil && !room.HasLeft {
directChats[room] = true
}
@@ -763,8 +795,13 @@ func (c *Container) HandleTyping(_ EventSource, evt *event.Event) {
}
func (c *Container) MarkRead(roomID id.RoomID, eventID id.EventID) {
- urlPath := c.client.BuildURL("rooms", roomID, "receipt", "m.read", eventID)
- _, _ = c.client.MakeRequest("POST", urlPath, struct{}{}, nil)
+ go func() {
+ defer debug.Recover()
+ err := c.client.MarkRead(roomID, eventID)
+ if err != nil {
+ debug.Print("Failed to mark %s in %s as read: %v", eventID, roomID, err)
+ }
+ }()
}
func (c *Container) PrepareMarkdownMessage(roomID id.RoomID, msgtype event.MessageType, text, html string, rel *ifc.Relation) *muksevt.Event {
@@ -824,8 +861,28 @@ func (c *Container) Redact(roomID id.RoomID, eventID id.EventID, reason string)
func (c *Container) SendEvent(evt *muksevt.Event) (id.EventID, error) {
defer debug.Recover()
- c.client.UserTyping(evt.RoomID, false, 0)
+ _, _ = c.client.UserTyping(evt.RoomID, false, 0)
c.typing = 0
+ room := c.GetRoom(evt.RoomID)
+ if room != nil && room.Encrypted && c.crypto != nil && evt.Type != event.EventReaction {
+ encrypted, err := c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content)
+ if err != nil {
+ if isBadEncryptError(err) {
+ return "", err
+ }
+ debug.Print("Got", err, "while trying to encrypt message, sharing group session and trying again...")
+ err = c.crypto.ShareGroupSession(room.ID, room.GetMemberList())
+ if err != nil {
+ return "", err
+ }
+ encrypted, err = c.crypto.EncryptMegolmEvent(evt.RoomID, evt.Type, evt.Content)
+ if err != nil {
+ return "", err
+ }
+ }
+ evt.Type = event.EventEncrypted
+ evt.Content = event.Content{Parsed: encrypted}
+ }
resp, err := c.client.SendMessageEvent(evt.RoomID, evt.Type, &evt.Content, mautrix.ReqSendEvent{TransactionID: evt.Unsigned.TransactionID})
if err != nil {
return "", err
@@ -923,11 +980,21 @@ func (c *Container) GetHistory(room *rooms.Room, limit int) ([]*muksevt.Event, e
return nil, err
}
debug.Printf("Loaded %d events for %s from server from %s to %s", len(resp.Chunk), room.ID, resp.Start, resp.End)
- for _, evt := range resp.Chunk {
+ for i, evt := range resp.Chunk {
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
debug.Printf("Failed to unmarshal content of event %s (type %s) by %s in %s: %v\n%s", evt.ID, evt.Type.Repr(), evt.Sender, evt.RoomID, err, string(evt.Content.VeryRaw))
}
+
+ if c.crypto != nil && evt.Type == event.EventEncrypted {
+ decrypted, err := c.crypto.DecryptMegolmEvent(evt)
+ if err != nil {
+ debug.Print("Failed to decrypt event:", err)
+ // TODO add decryption failed message instead of passing through directly
+ } else {
+ resp.Chunk[i] = decrypted
+ }
+ }
}
for _, evt := range resp.State {
room.UpdateState(evt)
@@ -956,9 +1023,12 @@ func (c *Container) GetEvent(room *rooms.Room, eventID id.EventID) (*muksevt.Eve
if err != nil {
return nil, err
}
- evt = muksevt.Wrap(mxEvent)
+ err = mxEvent.Content.ParseRaw(mxEvent.Type)
+ if err != nil {
+ return nil, err
+ }
debug.Printf("Loaded event %s from server", eventID)
- return evt, nil
+ return muksevt.Wrap(mxEvent), nil
}
// GetOrCreateRoom gets the room instance stored in the session.
@@ -991,7 +1061,7 @@ func cp(src, dst string) error {
return out.Close()
}
-func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath string, err error) {
+func (c *Container) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (fullPath string, err error) {
cachePath := c.GetCachePath(uri)
if target == "" {
fullPath = cachePath
@@ -1002,21 +1072,27 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
}
if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) {
- var file *os.File
- file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600)
+ var body io.ReadCloser
+ body, err = c.client.Download(uri)
if err != nil {
return
}
- defer file.Close()
- var body io.ReadCloser
- body, err = c.client.Download(uri)
+ var data []byte
+ data, err = ioutil.ReadAll(body)
+ _ = body.Close()
if err != nil {
return
}
- defer body.Close()
- _, err = io.Copy(file, body)
+ if file != nil {
+ data, err = file.Decrypt(data)
+ if err != nil {
+ return
+ }
+ }
+
+ err = ioutil.WriteFile(cachePath, data, 0600)
if err != nil {
return
}
@@ -1036,7 +1112,7 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
// Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors.
//
// The file will be either read from the media cache (if found) or downloaded from the server.
-func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
+func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) (data []byte, err error) {
cacheFile := c.GetCachePath(uri)
var info os.FileInfo
if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() {
@@ -1046,7 +1122,7 @@ func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
}
}
- data, err = c.download(uri, cacheFile)
+ data, err = c.download(uri, file, cacheFile)
return
}
@@ -1054,21 +1130,25 @@ func (c *Container) GetDownloadURL(uri id.ContentURI) string {
return c.client.GetDownloadURL(uri)
}
-func (c *Container) download(uri id.ContentURI, cacheFile string) (data []byte, err error) {
+func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) {
var body io.ReadCloser
body, err = c.client.Download(uri)
if err != nil {
return
}
- defer body.Close()
- var buf bytes.Buffer
- _, err = io.Copy(&buf, body)
+ data, err = ioutil.ReadAll(body)
+ _ = body.Close()
if err != nil {
return
}
- data = buf.Bytes()
+ if file != nil {
+ data, err = file.Decrypt(data)
+ if err != nil {
+ return
+ }
+ }
err = ioutil.WriteFile(cacheFile, data, 0600)
return