aboutsummaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/matrix.go42
1 files changed, 26 insertions, 16 deletions
diff --git a/matrix/matrix.go b/matrix/matrix.go
index 8d7595e..27ed053 100644
--- a/matrix/matrix.go
+++ b/matrix/matrix.go
@@ -17,7 +17,6 @@
package matrix
import (
- "bytes"
"context"
"crypto/tls"
"encoding/gob"
@@ -38,6 +37,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@@ -1061,7 +1061,7 @@ func cp(src, dst string) error {
return out.Close()
}
-func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath string, err error) {
+func (c *Container) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (fullPath string, err error) {
cachePath := c.GetCachePath(uri)
if target == "" {
fullPath = cachePath
@@ -1072,21 +1072,27 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
}
if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) {
- var file *os.File
- file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600)
+ var body io.ReadCloser
+ body, err = c.client.Download(uri)
if err != nil {
return
}
- defer file.Close()
- var body io.ReadCloser
- body, err = c.client.Download(uri)
+ var data []byte
+ data, err = ioutil.ReadAll(body)
+ _ = body.Close()
if err != nil {
return
}
- defer body.Close()
- _, err = io.Copy(file, body)
+ if file != nil {
+ data, err = file.Decrypt(data)
+ if err != nil {
+ return
+ }
+ }
+
+ err = ioutil.WriteFile(cachePath, data, 0600)
if err != nil {
return
}
@@ -1106,7 +1112,7 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
// Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors.
//
// The file will be either read from the media cache (if found) or downloaded from the server.
-func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
+func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) (data []byte, err error) {
cacheFile := c.GetCachePath(uri)
var info os.FileInfo
if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() {
@@ -1116,7 +1122,7 @@ func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
}
}
- data, err = c.download(uri, cacheFile)
+ data, err = c.download(uri, file, cacheFile)
return
}
@@ -1124,21 +1130,25 @@ func (c *Container) GetDownloadURL(uri id.ContentURI) string {
return c.client.GetDownloadURL(uri)
}
-func (c *Container) download(uri id.ContentURI, cacheFile string) (data []byte, err error) {
+func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) {
var body io.ReadCloser
body, err = c.client.Download(uri)
if err != nil {
return
}
- defer body.Close()
- var buf bytes.Buffer
- _, err = io.Copy(&buf, body)
+ data, err = ioutil.ReadAll(body)
+ _ = body.Close()
if err != nil {
return
}
- data = buf.Bytes()
+ if file != nil {
+ data, err = file.Decrypt(data)
+ if err != nil {
+ return
+ }
+ }
err = ioutil.WriteFile(cacheFile, data, 0600)
return