diff options
Diffstat (limited to 'matrix')
-rw-r--r-- | matrix/matrix.go | 96 |
1 files changed, 75 insertions, 21 deletions
diff --git a/matrix/matrix.go b/matrix/matrix.go index 852a410..651f6bb 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -29,7 +29,6 @@ import ( "os" "path" "path/filepath" - "regexp" "runtime" dbg "runtime/debug" "time" @@ -389,6 +388,10 @@ func (c *Container) HandlePreferences(source EventSource, evt *mautrix.Event) { } } +func (c *Container) Preferences() *config.UserPreferences { + return &c.config.Preferences +} + func (c *Container) SendPreferencesToMatrix() { defer debug.Recover() debug.Print("Sending updated preferences:", c.config.Preferences) @@ -926,22 +929,73 @@ func (c *Container) GetRoom(roomID string) *rooms.Room { return c.config.Rooms.Get(roomID) } -var mxcRegex = regexp.MustCompile("mxc://(.+)/(.+)") +func cp(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() -// Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors. -// -// The file will be either read from the media cache (if found) or downloaded from the server. -func (c *Container) Download(mxcURL string) (data []byte, hs, id string, err error) { - parts := mxcRegex.FindStringSubmatch(mxcURL) - if parts == nil || len(parts) != 3 { - err = fmt.Errorf("invalid matrix content URL") - return + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, in) + if err != nil { + return err + } + return out.Close() +} + +func (c *Container) DownloadToDisk(uri mautrix.ContentURI, target string) (fullPath string, err error) { + cachePath := c.GetCachePath(uri) + if target == "" { + fullPath = cachePath + } else if !path.IsAbs(target) { + fullPath = path.Join(c.config.DownloadDir, target) + } else { + fullPath = target + } + + if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) { + var file *os.File + file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return + } + defer file.Close() + + var resp *http.Response + resp, err = c.client.Client.Get(c.GetDownloadURL(uri)) + if err != nil { + return + } + defer resp.Body.Close() + + _, err = io.Copy(file, resp.Body) + if err != nil { + return + } } - hs = parts[1] - id = parts[2] + if fullPath != cachePath { + err = os.MkdirAll(path.Dir(fullPath), 0700) + if err != nil { + return + } + err = cp(cachePath, fullPath) + } + + return +} - cacheFile := c.GetCachePath(hs, id) +// Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors. +// +// The file will be either read from the media cache (if found) or downloaded from the server. +func (c *Container) Download(uri mautrix.ContentURI) (data []byte, err error) { + cacheFile := c.GetCachePath(uri) var info os.FileInfo if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() { data, err = ioutil.ReadFile(cacheFile) @@ -950,22 +1004,22 @@ func (c *Container) Download(mxcURL string) (data []byte, hs, id string, err err } } - data, err = c.download(hs, id, cacheFile) + data, err = c.download(uri, cacheFile) return } -func (c *Container) GetDownloadURL(hs, id string) string { +func (c *Container) GetDownloadURL(uri mautrix.ContentURI) string { dlURL, _ := url.Parse(c.client.HomeserverURL.String()) if dlURL.Scheme == "" { dlURL.Scheme = "https" } - dlURL.Path = path.Join(dlURL.Path, "/_matrix/media/v1/download", hs, id) + dlURL.Path = path.Join(dlURL.Path, "/_matrix/media/v1/download", uri.Homeserver, uri.FileID) return dlURL.String() } -func (c *Container) download(hs, id, cacheFile string) (data []byte, err error) { +func (c *Container) download(uri mautrix.ContentURI, cacheFile string) (data []byte, err error) { var resp *http.Response - resp, err = c.client.Client.Get(c.GetDownloadURL(hs, id)) + resp, err = c.client.Client.Get(c.GetDownloadURL(uri)) if err != nil { return } @@ -985,13 +1039,13 @@ func (c *Container) download(hs, id, cacheFile string) (data []byte, err error) // GetCachePath gets the path to the cached version of the given homeserver:fileID combination. // The file may or may not exist, use Download() to ensure it has been cached. -func (c *Container) GetCachePath(homeserver, fileID string) string { - dir := filepath.Join(c.config.MediaDir, homeserver) +func (c *Container) GetCachePath(uri mautrix.ContentURI) string { + dir := filepath.Join(c.config.MediaDir, uri.Homeserver) err := os.MkdirAll(dir, 0700) if err != nil { return "" } - return filepath.Join(dir, fileID) + return filepath.Join(dir, uri.FileID) } |