aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-05-24 16:23:19 +0100
committerRichard van der Hoff <richard@matrix.org>2016-05-24 16:23:19 +0100
commit1b15465c42a88f750a960a0e73f186245f9bba33 (patch)
treee0ecb976022b606e2872c490718ad5dd182be0d9
parenta919a149fbb192e3fae7aba921ca28e02d9c0d10 (diff)
Separate base64ing from the rest of msg encoding
Factor the actual message encoding/decoding and encrypting/decrypting out to separate functions from the top-level functions which do the base64-wrangling. This is particularly helpful in the 'outbound' code-path where the offsets required to allow room to base64-encode make the flow hard to see when it's all inline.
-rw-r--r--include/olm/message.h4
-rw-r--r--src/inbound_group_session.c64
-rw-r--r--src/message.cpp3
-rw-r--r--src/outbound_group_session.c65
4 files changed, 94 insertions, 42 deletions
diff --git a/include/olm/message.h b/include/olm/message.h
index cff15f3..e80d54c 100644
--- a/include/olm/message.h
+++ b/include/olm/message.h
@@ -53,8 +53,10 @@ size_t _olm_encode_group_message_length(
* olm_encode_group_message_length() bytes long.
* ciphertext_ptr: returns the address that the ciphertext
* should be written to, followed by the MAC.
+ *
+ * Returns the size of the message, up to the MAC.
*/
-void _olm_encode_group_message(
+size_t _olm_encode_group_message(
uint8_t version,
const uint8_t *session_id,
size_t session_id_length,
diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c
index b6894c1..e171205 100644
--- a/src/inbound_group_session.c
+++ b/src/inbound_group_session.c
@@ -163,19 +163,15 @@ size_t olm_unpickle_inbound_group_session(
return pickled_length;
}
-size_t olm_group_decrypt_max_plaintext_length(
+/**
+ * get the max plaintext length in an un-base64-ed message
+ */
+static size_t _decrypt_max_plaintext_length(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length
) {
- size_t r;
struct _OlmDecodeGroupMessageResults decoded_results;
- r = _olm_decode_base64(message, message_length, message);
- if (r == (size_t)-1) {
- session->last_error = OLM_INVALID_BASE64;
- return r;
- }
-
_olm_decode_group_message(
message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
@@ -195,25 +191,38 @@ size_t olm_group_decrypt_max_plaintext_length(
megolm_cipher, decoded_results.ciphertext_length);
}
+size_t olm_group_decrypt_max_plaintext_length(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length
+) {
+ size_t raw_length;
-size_t olm_group_decrypt(
+ raw_length = _olm_decode_base64(message, message_length, message);
+ if (raw_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ return _decrypt_max_plaintext_length(
+ session, message, raw_length
+ );
+}
+
+/**
+ * decrypt an un-base64-ed message
+ */
+static size_t _decrypt(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length
) {
struct _OlmDecodeGroupMessageResults decoded_results;
- size_t max_length, raw_message_length, r;
+ size_t max_length, r;
Megolm *megolm;
Megolm tmp_megolm;
- raw_message_length = _olm_decode_base64(message, message_length, message);
- if (raw_message_length == (size_t)-1) {
- session->last_error = OLM_INVALID_BASE64;
- return (size_t)-1;
- }
-
_olm_decode_group_message(
- message, raw_message_length,
+ message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
&decoded_results);
@@ -259,7 +268,7 @@ size_t olm_group_decrypt(
r = megolm_cipher->ops->decrypt(
megolm_cipher,
megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH,
- message, raw_message_length,
+ message, message_length,
decoded_results.ciphertext, decoded_results.ciphertext_length,
plaintext, max_plaintext_length
);
@@ -272,3 +281,22 @@ size_t olm_group_decrypt(
return r;
}
+
+size_t olm_group_decrypt(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length,
+ uint8_t * plaintext, size_t max_plaintext_length
+) {
+ size_t raw_message_length;
+
+ raw_message_length = _olm_decode_base64(message, message_length, message);
+ if (raw_message_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ return _decrypt(
+ session, message, raw_message_length,
+ plaintext, max_plaintext_length
+ );
+}
diff --git a/src/message.cpp b/src/message.cpp
index ab4300e..2e841e5 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -347,7 +347,7 @@ size_t _olm_encode_group_message_length(
}
-void _olm_encode_group_message(
+size_t _olm_encode_group_message(
uint8_t version,
const uint8_t *session_id,
size_t session_id_length,
@@ -364,6 +364,7 @@ void _olm_encode_group_message(
std::memcpy(session_id_pos, session_id, session_id_length);
pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index);
pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
+ return pos-output;
}
void _olm_decode_group_message(
diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c
index 9f36ad8..9b2298a 100644
--- a/src/outbound_group_session.c
+++ b/src/outbound_group_session.c
@@ -199,51 +199,41 @@ size_t olm_group_encrypt_message_length(
return _olm_encode_base64_length(message_length);
}
-
-size_t olm_group_encrypt(
- OlmOutboundGroupSession *session,
- uint8_t const * plaintext, size_t plaintext_length,
- uint8_t * message, size_t max_message_length
+/** write an un-base64-ed message to the buffer */
+static size_t _encrypt(
+ OlmOutboundGroupSession *session, uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * buffer
) {
- size_t ciphertext_length;
- size_t rawmsglen;
+ size_t ciphertext_length, mac_length, message_length;
size_t result;
- uint8_t *ciphertext_ptr, *message_pos;
-
- rawmsglen = raw_message_length(session, plaintext_length);
-
- if (max_message_length < _olm_encode_base64_length(rawmsglen)) {
- session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
- return (size_t)-1;
- }
+ uint8_t *ciphertext_ptr;
ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length(
megolm_cipher,
plaintext_length
);
- /* we construct the message at the end of the buffer, so that
- * we have room to base64-encode it once we're done.
- */
- message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen;
+ mac_length = megolm_cipher->ops->mac_length(megolm_cipher);
/* first we build the message structure, then we encrypt
* the plaintext into it.
*/
- _olm_encode_group_message(
+ message_length = _olm_encode_group_message(
OLM_PROTOCOL_VERSION,
session->session_id, GROUP_SESSION_ID_LENGTH,
session->ratchet.counter,
ciphertext_length,
- message_pos,
+ buffer,
&ciphertext_ptr);
+ message_length += mac_length;
+
result = megolm_cipher->ops->encrypt(
megolm_cipher,
megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH,
plaintext, plaintext_length,
ciphertext_ptr, ciphertext_length,
- message_pos, rawmsglen
+ buffer, message_length
);
if (result == (size_t)-1) {
@@ -252,6 +242,37 @@ size_t olm_group_encrypt(
megolm_advance(&(session->ratchet));
+ return result;
+}
+
+size_t olm_group_encrypt(
+ OlmOutboundGroupSession *session,
+ uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * message, size_t max_message_length
+) {
+ size_t rawmsglen;
+ size_t result;
+ uint8_t *message_pos;
+
+ rawmsglen = raw_message_length(session, plaintext_length);
+
+ if (max_message_length < _olm_encode_base64_length(rawmsglen)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ /* we construct the message at the end of the buffer, so that
+ * we have room to base64-encode it once we're done.
+ */
+ message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen;
+
+ /* write the message, and encrypt it, at message_pos */
+ result = _encrypt(session, plaintext, plaintext_length, message_pos);
+ if (result == (size_t)-1) {
+ return result;
+ }
+
+ /* bas64-encode it */
return _olm_encode_base64(
message_pos, rawmsglen, message
);