From 158f7ee8919929b4daefb484d1a232d367d0c8e5 Mon Sep 17 00:00:00 2001
From: Mark Haines <mark.haines@matrix.org>
Date: Fri, 7 Aug 2015 19:33:48 +0100
Subject: Fix crash where the message length was shorter than the length of the
 mac

---
 src/message.cpp            |  1 +
 tests/test_olm_decrypt.cpp | 11 ++++++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/message.cpp b/src/message.cpp
index ffb9f6c..f98dfe5 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -213,6 +213,7 @@ void olm::decode_message(
     reader.ciphertext_length = 0;
 
     if (pos == end) return;
+    if (input_length < mac_length) return;
     reader.version = *(pos++);
 
     while (pos != end) {
diff --git a/tests/test_olm_decrypt.cpp b/tests/test_olm_decrypt.cpp
index 3f2a994..2a2db98 100644
--- a/tests/test_olm_decrypt.cpp
+++ b/tests/test_olm_decrypt.cpp
@@ -5,6 +5,7 @@ const char * test_cases[] = {
     "41776f",
     "7fff6f0101346d671201",
     "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d",
+    "e9e9c9c1e9e9c9e9c9c1e9e9c9c1",
 };
 
 
@@ -39,14 +40,17 @@ void decrypt_case(int message_type, const char * test_case) {
     ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled));
 
     std::size_t message_length = strlen(test_case) / 2;
-    std::uint8_t message[message_length];
+    std::uint8_t * message = (std::uint8_t *) ::malloc(message_length);
     decode_hex(test_case, message, message_length);
 
     size_t max_length = olm_decrypt_max_plaintext_length(
         session, message_type, message, message_length
     );
 
-    if (max_length == std::size_t(-1)) return;
+    if (max_length == std::size_t(-1)) {
+        free(message);
+        return;
+    }
 
     uint8_t plaintext[max_length];
     decode_hex(test_case, message, message_length);
@@ -55,12 +59,13 @@ void decrypt_case(int message_type, const char * test_case) {
         message, message_length,
         plaintext, max_length
     );
+    free(message);
 }
 
 
 int main() {
 {
-TestCase("Olm decrypt test");
+TestCase my_test("Olm decrypt test");
 
 for (int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) {
     decrypt_case(0, test_cases[i]);
-- 
cgit v1.2.3-70-g09d2