aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-05-18 17:23:09 +0100
committerRichard van der Hoff <richard@matrix.org>2016-05-24 13:39:34 +0100
commit39ad75314b9e28053f568ed6a4109f5d3a9468fe (patch)
tree72f7453ebbcbaa2513391c87b8b960092bb05ffa
parent8b1514c0a653ccc3f49db70131d7d4f7524f1f9b (diff)
Implement decrypting inbound group messages
Includes creation of inbound sessions, etc
-rw-r--r--include/olm/error.h3
-rw-r--r--include/olm/inbound_group_session.h153
-rw-r--r--include/olm/message.h24
-rw-r--r--include/olm/olm.h1
-rw-r--r--src/inbound_group_session.c199
-rw-r--r--src/message.cpp42
-rw-r--r--tests/test_group_session.cpp42
-rw-r--r--tests/test_message.cpp22
8 files changed, 480 insertions, 6 deletions
diff --git a/include/olm/error.h b/include/olm/error.h
index 87e019a..3f74992 100644
--- a/include/olm/error.h
+++ b/include/olm/error.h
@@ -32,6 +32,9 @@ enum OlmErrorCode {
OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */
OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */
+ OLM_BAD_RATCHET_KEY = 11,
+ OLM_BAD_CHAIN_INDEX = 12,
+
/* remember to update the list of string constants in error.c when updating
* this list. */
};
diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h
new file mode 100644
index 0000000..4cf4ac4
--- /dev/null
+++ b/include/olm/inbound_group_session.h
@@ -0,0 +1,153 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef OLM_INBOUND_GROUP_SESSION_H_
+#define OLM_INBOUND_GROUP_SESSION_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct OlmInboundGroupSession OlmInboundGroupSession;
+
+/** get the size of an inbound group session, in bytes. */
+size_t olm_inbound_group_session_size();
+
+/**
+ * Initialise an inbound group session object using the supplied memory
+ * The supplied memory should be at least olm_inbound_group_session_size()
+ * bytes.
+ */
+OlmInboundGroupSession * olm_inbound_group_session(
+ void *memory
+);
+
+/**
+ * A null terminated string describing the most recent error to happen to a
+ * group session */
+const char *olm_inbound_group_session_last_error(
+ const OlmInboundGroupSession *session
+);
+
+/** Clears the memory used to back this group session */
+size_t olm_clear_inbound_group_session(
+ OlmInboundGroupSession *session
+);
+
+/** Returns the number of bytes needed to store an inbound group session */
+size_t olm_pickle_inbound_group_session_length(
+ const OlmInboundGroupSession *session
+);
+
+/**
+ * Stores a group session as a base64 string. Encrypts the session using the
+ * supplied key. Returns the length of the session on success.
+ *
+ * Returns olm_error() on failure. If the pickle output buffer
+ * is smaller than olm_pickle_inbound_group_session_length() then
+ * olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL"
+ */
+size_t olm_pickle_inbound_group_session(
+ OlmInboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+);
+
+/**
+ * Loads a group session from a pickled base64 string. Decrypts the session
+ * using the supplied key.
+ *
+ * Returns olm_error() on failure. If the key doesn't match the one used to
+ * encrypt the account then olm_inbound_group_session_last_error() will be
+ * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then
+ * olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input
+ * pickled buffer is destroyed
+ */
+size_t olm_unpickle_inbound_group_session(
+ OlmInboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+);
+
+
+/**
+ * Start a new inbound group session, based on the parameters supplied.
+ *
+ * Returns olm_error() on failure. On failure last_error will be set with an
+ * error code. The last_error will be:
+ *
+ * * OLM_INVALID_BASE64 if the session_key is not valid base64
+ * * OLM_BAD_RATCHET_KEY if the session_key is invalid
+ */
+size_t olm_init_inbound_group_session(
+ OlmInboundGroupSession *session,
+ uint32_t message_index,
+
+ /* base64-encoded key */
+ uint8_t const * session_key, size_t session_key_length
+);
+
+/**
+ * Get an upper bound on the number of bytes of plain-text the decrypt method
+ * will write for a given input message length. The actual size could be
+ * different due to padding.
+ *
+ * The input message buffer is destroyed.
+ *
+ * Returns olm_error() on failure.
+ */
+size_t olm_group_decrypt_max_plaintext_length(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length
+);
+
+/**
+ * Decrypt a message.
+ *
+ * The input message buffer is destroyed.
+ *
+ * Returns the length of the decrypted plain-text, or olm_error() on failure.
+ *
+ * On failure last_error will be set with an error code. The last_error will
+ * be:
+ * * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small
+ * * OLM_INVALID_BASE64 if the message is not valid base-64
+ * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported
+ * version of the protocol
+ * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded
+ * * OLM_BAD_MESSAGE_MAC if the message could not be verified
+ * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the
+ * message's index (ie, it was sent before the ratchet key was shared with
+ * us)
+ */
+size_t olm_group_decrypt(
+ OlmInboundGroupSession *session,
+
+ /* input; note that it will be overwritten with the base64-decoded
+ message. */
+ uint8_t * message, size_t message_length,
+
+ /* output */
+ uint8_t * plaintext, size_t max_plaintext_length
+);
+
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif /* OLM_INBOUND_GROUP_SESSION_H_ */
diff --git a/include/olm/message.h b/include/olm/message.h
index 05fb56c..bd7aec3 100644
--- a/include/olm/message.h
+++ b/include/olm/message.h
@@ -65,6 +65,30 @@ void _olm_encode_group_message(
);
+struct _OlmDecodeGroupMessageResults {
+ uint8_t version;
+ const uint8_t *session_id;
+ size_t session_id_length;
+ uint32_t chain_index;
+ int has_chain_index;
+ const uint8_t *ciphertext;
+ size_t ciphertext_length;
+};
+
+
+/**
+ * Reads the message headers from the input buffer.
+ */
+void _olm_decode_group_message(
+ const uint8_t *input, size_t input_length,
+ size_t mac_length,
+
+ /* output structure: updated with results */
+ struct _OlmDecodeGroupMessageResults *results
+);
+
+
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/include/olm/olm.h b/include/olm/olm.h
index 00e1f63..dbaf71e 100644
--- a/include/olm/olm.h
+++ b/include/olm/olm.h
@@ -19,6 +19,7 @@
#include <stddef.h>
#include <stdint.h>
+#include "olm/inbound_group_session.h"
#include "olm/outbound_group_session.h"
#ifdef __cplusplus
diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c
new file mode 100644
index 0000000..4796414
--- /dev/null
+++ b/src/inbound_group_session.c
@@ -0,0 +1,199 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/inbound_group_session.h"
+
+#include <string.h>
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/error.h"
+#include "olm/megolm.h"
+#include "olm/message.h"
+
+#define OLM_PROTOCOL_VERSION 3
+
+struct OlmInboundGroupSession {
+ /** our earliest known ratchet value */
+ Megolm initial_ratchet;
+
+ /** The most recent ratchet value */
+ Megolm latest_ratchet;
+
+ enum OlmErrorCode last_error;
+};
+
+size_t olm_inbound_group_session_size() {
+ return sizeof(OlmInboundGroupSession);
+}
+
+OlmInboundGroupSession * olm_inbound_group_session(
+ void *memory
+) {
+ OlmInboundGroupSession *session = memory;
+ olm_clear_inbound_group_session(session);
+ return session;
+}
+
+const char *olm_inbound_group_session_last_error(
+ const OlmInboundGroupSession *session
+) {
+ return _olm_error_to_string(session->last_error);
+}
+
+size_t olm_clear_inbound_group_session(
+ OlmInboundGroupSession *session
+) {
+ memset(session, 0, sizeof(OlmInboundGroupSession));
+ return sizeof(OlmInboundGroupSession);
+}
+
+size_t olm_init_inbound_group_session(
+ OlmInboundGroupSession *session,
+ uint32_t message_index,
+ const uint8_t * session_key, size_t session_key_length
+) {
+ uint8_t key_buf[MEGOLM_RATCHET_LENGTH];
+ size_t raw_length = _olm_decode_base64_length(session_key_length);
+
+ if (raw_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ if (raw_length != MEGOLM_RATCHET_LENGTH) {
+ session->last_error = OLM_BAD_RATCHET_KEY;
+ return (size_t)-1;
+ }
+
+ _olm_decode_base64(session_key, session_key_length, key_buf);
+ megolm_init(&session->initial_ratchet, key_buf, message_index);
+ megolm_init(&session->latest_ratchet, key_buf, message_index);
+ memset(key_buf, 0, MEGOLM_RATCHET_LENGTH);
+
+ return 0;
+}
+
+size_t olm_group_decrypt_max_plaintext_length(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length
+) {
+ size_t r;
+ const struct _olm_cipher *cipher = megolm_cipher();
+ struct _OlmDecodeGroupMessageResults decoded_results;
+
+ r = _olm_decode_base64(message, message_length, message);
+ if (r == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return r;
+ }
+
+ _olm_decode_group_message(
+ message, message_length,
+ cipher->ops->mac_length(cipher),
+ &decoded_results);
+
+ if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+ session->last_error = OLM_BAD_MESSAGE_VERSION;
+ return (size_t)-1;
+ }
+
+ if (!decoded_results.ciphertext) {
+ session->last_error = OLM_BAD_MESSAGE_FORMAT;
+ return (size_t)-1;
+ }
+
+ return cipher->ops->decrypt_max_plaintext_length(
+ cipher, decoded_results.ciphertext_length);
+}
+
+
+size_t olm_group_decrypt(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length,
+ uint8_t * plaintext, size_t max_plaintext_length
+) {
+ struct _OlmDecodeGroupMessageResults decoded_results;
+ const struct _olm_cipher *cipher = megolm_cipher();
+ size_t max_length, raw_message_length, r;
+ Megolm *megolm;
+ Megolm tmp_megolm;
+
+ raw_message_length = _olm_decode_base64(message, message_length, message);
+ if (raw_message_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ _olm_decode_group_message(
+ message, raw_message_length,
+ cipher->ops->mac_length(cipher),
+ &decoded_results);
+
+ if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+ session->last_error = OLM_BAD_MESSAGE_VERSION;
+ return (size_t)-1;
+ }
+
+ if (!decoded_results.has_chain_index || !decoded_results.session_id
+ || !decoded_results.ciphertext
+ ) {
+ session->last_error = OLM_BAD_MESSAGE_FORMAT;
+ return (size_t)-1;
+ }
+
+ max_length = cipher->ops->decrypt_max_plaintext_length(
+ cipher,
+ decoded_results.ciphertext_length
+ );
+ if (max_plaintext_length < max_length) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ /* pick a megolm instance to use. If we're at or beyond the latest ratchet
+ * value, use that */
+ if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) {
+ megolm = &session->latest_ratchet;
+ } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) {
+ /* the counter is before our intial ratchet - we can't decode this. */
+ session->last_error = OLM_BAD_CHAIN_INDEX;
+ return (size_t)-1;
+ } else {
+ /* otherwise, start from the initial megolm. Take a copy so that we
+ * don't overwrite the initial megolm */
+ tmp_megolm = session->initial_ratchet;
+ megolm = &tmp_megolm;
+ }
+
+ megolm_advance_to(megolm, decoded_results.chain_index);
+
+ /* now try checking the mac, and decrypting */
+ r = cipher->ops->decrypt(
+ cipher,
+ megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH,
+ message, raw_message_length,
+ decoded_results.ciphertext, decoded_results.ciphertext_length,
+ plaintext, max_plaintext_length
+ );
+
+ memset(&tmp_megolm, 0, sizeof(tmp_megolm));
+ if (r == (size_t)-1) {
+ session->last_error = OLM_BAD_MESSAGE_MAC;
+ return r;
+ }
+
+ return r;
+}
diff --git a/src/message.cpp b/src/message.cpp
index df0c7bb..ec44262 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -363,3 +363,45 @@ void _olm_encode_group_message(
pos = encode(pos, COUNTER_TAG, chain_index);
pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
}
+
+void _olm_decode_group_message(
+ const uint8_t *input, size_t input_length,
+ size_t mac_length,
+ struct _OlmDecodeGroupMessageResults *results
+) {
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = input + input_length - mac_length;
+ std::uint8_t const * unknown = nullptr;
+
+ results->session_id = nullptr;
+ results->session_id_length = 0;
+ bool has_chain_index = false;
+ results->chain_index = 0;
+ results->ciphertext = nullptr;
+ results->ciphertext_length = 0;
+
+ if (pos == end) return;
+ if (input_length < mac_length) return;
+ results->version = *(pos++);
+
+ while (pos != end) {
+ pos = decode(
+ pos, end, GROUP_SESSION_ID_TAG,
+ results->session_id, results->session_id_length
+ );
+ pos = decode(
+ pos, end, COUNTER_TAG,
+ results->chain_index, has_chain_index
+ );
+ pos = decode(
+ pos, end, CIPHERTEXT_TAG,
+ results->ciphertext, results->ciphertext_length
+ );
+ if (unknown == pos) {
+ pos = skip_unknown(pos, end);
+ }
+ unknown = pos;
+ }
+
+ results->has_chain_index = (int)has_chain_index;
+}
diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp
index b9fe1ef..5bbdc9d 100644
--- a/tests/test_group_session.cpp
+++ b/tests/test_group_session.cpp
@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#include "olm/inbound_group_session.h"
#include "olm/outbound_group_session.h"
#include "unittest.hh"
@@ -19,11 +20,10 @@
int main() {
{
-
TestCase test_case("Pickle outbound group");
size_t size = olm_outbound_group_session_size();
- void *memory = alloca(size);
+ uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
size_t pickle_length = olm_pickle_outbound_group_session_length(session);
@@ -61,9 +61,9 @@ int main() {
"0123456789ABDEF0123456789ABCDEF";
-
+ /* build the outbound session */
size_t size = olm_outbound_group_session_size();
- void *memory = alloca(size);
+ uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
assert_equals((size_t)132,
@@ -73,18 +73,48 @@ int main() {
session, random_bytes, sizeof(random_bytes));
assert_equals((size_t)0, res);
+ assert_equals(0U, olm_outbound_group_session_message_index(session));
+ size_t session_key_len = olm_outbound_group_session_key_length(session);
+ uint8_t session_key[session_key_len];
+ olm_outbound_group_session_key(session, session_key, session_key_len);
+
+
+ /* encode the message */
uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1;
size_t msglen = olm_group_encrypt_message_length(
session, plaintext_length);
- uint8_t *msg = (uint8_t *)alloca(msglen);
+ uint8_t msg[msglen];
res = olm_group_encrypt(session, plaintext, plaintext_length,
msg, msglen);
assert_equals(msglen, res);
+ assert_equals(1U, olm_outbound_group_session_message_index(session));
+
+
+ /* build the inbound session */
+ size = olm_inbound_group_session_size();
+ uint8_t inbound_session_memory[size];
+ OlmInboundGroupSession *inbound_session =
+ olm_inbound_group_session(inbound_session_memory);
+
+ res = olm_init_inbound_group_session(
+ inbound_session, 0U, session_key, session_key_len);
+ assert_equals((size_t)0, res);
- // TODO: decode the message
+ /* decode the message */
+
+ /* olm_group_decrypt_max_plaintext_length destroys the input so we have to
+ copy it. */
+ uint8_t msgcopy[msglen];
+ memcpy(msgcopy, msg, msglen);
+ size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
+ uint8_t plaintext_buf[size];
+ res = olm_group_decrypt(inbound_session, msg, msglen,
+ plaintext_buf, size);
+ assert_equals(plaintext_length, res);
+ assert_equals(plaintext, plaintext_buf, res);
}
}
diff --git a/tests/test_message.cpp b/tests/test_message.cpp
index e2385ea..5fec9e0 100644
--- a/tests/test_message.cpp
+++ b/tests/test_message.cpp
@@ -97,4 +97,26 @@ assert_equals(message2, output, 35);
assert_equals(output+sizeof(expected)-1, ciphertext_ptr);
} /* group message encode test */
+{
+ TestCase test_case("Group message decode test");
+
+ struct _OlmDecodeGroupMessageResults results;
+ std::uint8_t message[] =
+ "\x03"
+ "\x2A\x09sessionid"
+ "\x10\xc8\x01"
+ "\x22\x0A" "ciphertext"
+ "hmacsha2";
+
+ const uint8_t expected_session_id[] = "sessionid";
+
+ _olm_decode_group_message(message, sizeof(message)-1, 8, &results);
+ assert_equals(std::uint8_t(3), results.version);
+ assert_equals(std::size_t(9), results.session_id_length);
+ assert_equals(expected_session_id, results.session_id, 9);
+ assert_equals(1, results.has_chain_index);
+ assert_equals(std::uint32_t(200), results.chain_index);
+ assert_equals(std::size_t(10), results.ciphertext_length);
+ assert_equals(ciphertext, results.ciphertext, 10);
+} /* group message decode test */
}