aboutsummaryrefslogtreecommitdiff
path: root/tests/test_olm_decrypt.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_olm_decrypt.cpp')
-rw-r--r--tests/test_olm_decrypt.cpp41
1 files changed, 28 insertions, 13 deletions
diff --git a/tests/test_olm_decrypt.cpp b/tests/test_olm_decrypt.cpp
index 95cb18e..4a1fb97 100644
--- a/tests/test_olm_decrypt.cpp
+++ b/tests/test_olm_decrypt.cpp
@@ -1,11 +1,16 @@
#include "olm/olm.h"
#include "unittest.hh"
-const char * test_cases[] = {
- "41776f",
- "7fff6f0101346d671201",
- "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d",
- "e9e9c9c1e9e9c9e9c9c1e9e9c9c1",
+struct test_case {
+ const char *msghex;
+ const char *expected_error;
+};
+
+const test_case test_cases[] = {
+ { "41776f", "BAD_MESSAGE_FORMAT" },
+ { "7fff6f0101346d671201", "BAD_MESSAGE_FORMAT" },
+ { "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d", "BAD_MESSAGE_FORMAT" },
+ { "e9e9c9c1e9e9c9e9c9c1e9e9c9c1", "BAD_MESSAGE_FORMAT" },
};
@@ -31,29 +36,39 @@ void decode_hex(
}
}
-void decrypt_case(int message_type, const char * test_case) {
+void decrypt_case(int message_type, const test_case * test_case) {
std::uint8_t session_memory[olm_session_size()];
::OlmSession * session = ::olm_session(session_memory);
std::uint8_t pickled[strlen(session_data)];
::memcpy(pickled, session_data, sizeof(pickled));
- ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled));
+ assert_not_equals(
+ ::olm_error(),
+ ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled))
+ );
- std::size_t message_length = strlen(test_case) / 2;
+ std::size_t message_length = strlen(test_case->msghex) / 2;
std::uint8_t * message = (std::uint8_t *) ::malloc(message_length);
- decode_hex(test_case, message, message_length);
+ decode_hex(test_case->msghex, message, message_length);
size_t max_length = olm_decrypt_max_plaintext_length(
session, message_type, message, message_length
);
- if (max_length == std::size_t(-1)) {
+ if (test_case->expected_error) {
+ assert_equals(::olm_error(), max_length);
+ assert_equals(
+ std::string(test_case->expected_error),
+ std::string(::olm_session_last_error(session))
+ );
free(message);
return;
}
+ assert_not_equals(::olm_error(), max_length);
+
uint8_t plaintext[max_length];
- decode_hex(test_case, message, message_length);
+ decode_hex(test_case->msghex, message, message_length);
olm_decrypt(
session, message_type,
message, message_length,
@@ -67,8 +82,8 @@ int main() {
{
TestCase my_test("Olm decrypt test");
-for (unsigned int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) {
- decrypt_case(0, test_cases[i]);
+for (unsigned int i = 0; i < sizeof(test_cases)/ sizeof(test_cases[0]); ++i) {
+ decrypt_case(0, &test_cases[i]);
}
}