From 68d3c7bfa9d0d2f8a44edcd2d277c4a516ed6ed5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 27 Apr 2016 17:16:55 +0100 Subject: Implementation of the megolm ratchet --- tests/test_megolm.cpp | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/test_megolm.cpp (limited to 'tests') diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp new file mode 100644 index 0000000..871de36 --- /dev/null +++ b/tests/test_megolm.cpp @@ -0,0 +1,85 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/megolm.h" +#include "olm/memory.hh" + +#include "unittest.hh" + + +int main() { + +std::uint8_t random_bytes[] = + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF"; + +{ + TestCase test_case("Megolm::advance"); + + Megolm mr; + + megolm_init(&mr, random_bytes, 0); + // single-step advance + megolm_advance(&mr); + const std::uint8_t expected1[] = { + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a + }; + assert_equals(1U, mr.counter); + assert_equals(expected1, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + // repeat with complex advance + megolm_init(&mr, random_bytes, 0); + megolm_advance_to(&mr, 1); + assert_equals(1U, mr.counter); + assert_equals(expected1, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + megolm_advance_to(&mr, 0x1000000); + const std::uint8_t expected2[] = { + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32, + 0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f, + 0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd, + 0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, + }; + assert_equals(0x1000000U, mr.counter); + assert_equals(expected2, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + megolm_advance_to(&mr, 0x1041506); + const std::uint8_t expected3[] = { + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f, + 0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd, + 0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68, + 0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b, + 0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c, + 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a + }; + assert_equals(0x1041506U, mr.counter); + assert_equals(expected3, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); +} + +} -- cgit v1.2.3-70-g09d2 From caaed796ad54de3f8ee1e56123973ae9ace346b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 11:52:06 +0100 Subject: Implementation of an outbound group session --- include/olm/error.h | 1 - include/olm/megolm.h | 8 ++ include/olm/message.h | 72 ++++++++++++++ include/olm/message.hh | 12 +++ include/olm/olm.h | 2 + include/olm/outbound_group_session.h | 90 +++++++++++++++++ src/megolm.c | 14 +++ src/message.cpp | 40 +++++++- src/outbound_group_session.c | 183 +++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 56 +++++++++++ tests/test_message.cpp | 37 ++++++- 11 files changed, 512 insertions(+), 3 deletions(-) create mode 100644 include/olm/message.h create mode 100644 include/olm/outbound_group_session.h create mode 100644 src/outbound_group_session.c create mode 100644 tests/test_group_session.cpp (limited to 'tests') diff --git a/include/olm/error.h b/include/olm/error.h index 460017e..87e019a 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -34,7 +34,6 @@ enum OlmErrorCode { /* remember to update the list of string constants in error.c when updating * this list. */ - }; /** get a string representation of the given error code. */ diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 784597e..5cae353 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -47,6 +47,14 @@ typedef struct Megolm { uint32_t counter; } Megolm; + +/** + * Get the cipher used in megolm-backed conversations + * + * (AES256 + SHA256, with keys based on an HKDF with info of MEGOLM_KEYS) + */ +const struct _olm_cipher *megolm_cipher(); + /** * initialize the megolm ratchet. random_data should be at least * MEGOLM_RATCHET_LENGTH bytes of randomness. diff --git a/include/olm/message.h b/include/olm/message.h new file mode 100644 index 0000000..05fb56c --- /dev/null +++ b/include/olm/message.h @@ -0,0 +1,72 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * functions for encoding and decoding messages in the Olm protocol. + * + * Some of these functions have only C++ bindings, and are declared in + * message.hh; in time, they should probably be converted to plain C and + * declared here. + */ + +#ifndef OLM_MESSAGE_H_ +#define OLM_MESSAGE_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * The length of the buffer needed to hold a group message. + */ +size_t _olm_encode_group_message_length( + size_t group_session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + size_t mac_length +); + +/** + * Writes the message headers into the output buffer. + * + * version: version number of the olm protocol + * session_id: group session identifier + * session_id_length: length of session_id + * chain_index: message index + * ciphertext_length: length of the ciphertext + * output: where to write the output. Should be at least + * olm_encode_group_message_length() bytes long. + * ciphertext_ptr: returns the address that the ciphertext + * should be written to, followed by the MAC. + */ +void _olm_encode_group_message( + uint8_t version, + const uint8_t *session_id, + size_t session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + uint8_t *output, + uint8_t **ciphertext_ptr +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_MESSAGE_H_ */ diff --git a/include/olm/message.hh b/include/olm/message.hh index 5ce0a62..bd912d9 100644 --- a/include/olm/message.hh +++ b/include/olm/message.hh @@ -12,6 +12,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + + +/** + * functions for encoding and decoding messages in the Olm protocol. + * + * Some of these functions have plain-C bindings, and are declared in + * message.h; in time, all of the functions declared here should probably be + * converted to plain C and moved to message.h. + */ + +#include "message.h" + #include #include diff --git a/include/olm/olm.h b/include/olm/olm.h index 8abac49..00e1f63 100644 --- a/include/olm/olm.h +++ b/include/olm/olm.h @@ -19,6 +19,8 @@ #include #include +#include "olm/outbound_group_session.h" + #ifdef __cplusplus extern "C" { #endif diff --git a/include/olm/outbound_group_session.h b/include/olm/outbound_group_session.h new file mode 100644 index 0000000..6c02370 --- /dev/null +++ b/include/olm/outbound_group_session.h @@ -0,0 +1,90 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OLM_OUTBOUND_GROUP_SESSION_H_ +#define OLM_OUTBOUND_GROUP_SESSION_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct OlmOutboundGroupSession OlmOutboundGroupSession; + +/** get the size of an outbound group session, in bytes. */ +size_t olm_outbound_group_session_size(); + +/** + * Initialise an outbound group session object using the supplied memory + * The supplied memory should be at least olm_outbound_group_session_size() + * bytes. + */ +OlmOutboundGroupSession * olm_outbound_group_session( + void *memory +); + +/** + * A null terminated string describing the most recent error to happen to a + * group session */ +const char *olm_outbound_group_session_last_error( + const OlmOutboundGroupSession *session +); + +/** Clears the memory used to back this group session */ +size_t olm_clear_outbound_group_session( + OlmOutboundGroupSession *session +); + +/** The number of random bytes needed to create an outbound group session */ +size_t olm_init_outbound_group_session_random_length( + const OlmOutboundGroupSession *session +); + +/** + * Start a new outbound group session. Returns std::size_t(-1) on failure. On + * failure last_error will be set with an error code. The last_error will be + * NOT_ENOUGH_RANDOM if the number of random bytes was too small. + */ +size_t olm_init_outbound_group_session( + OlmOutboundGroupSession *session, + uint8_t const * random, size_t random_length +); + +/** + * The number of bytes that will be created by encrypting a message + */ +size_t olm_group_encrypt_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length +); + +/** + * Encrypt some plain-text. Returns the length of the encrypted message or + * std::size_t(-1) on failure. On failure last_error will be set with an + * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the output + * buffer is too small. + */ +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t message_length +); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_OUTBOUND_GROUP_SESSION_H_ */ diff --git a/src/megolm.c b/src/megolm.c index 36b0cc2..58fe725 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -18,8 +18,22 @@ #include +#include "olm/cipher.h" #include "olm/crypto.h" +const struct _olm_cipher *megolm_cipher() { + static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; + static struct _olm_cipher *cipher; + static struct _olm_cipher_aes_sha_256 OLM_CIPHER; + if (!cipher) { + cipher = _olm_cipher_aes_sha_256_init( + &OLM_CIPHER, + CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 + ); + } + return cipher; +} + /* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. */ #define HASH_KEY_SEED_LENGTH 1 diff --git a/src/message.cpp b/src/message.cpp index 3be5234..df0c7bb 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -1,4 +1,4 @@ -/* Copyright 2015 OpenMarket Ltd +/* Copyright 2015-2016 OpenMarket Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -325,3 +325,41 @@ void olm::decode_one_time_key_message( unknown = pos; } } + + + +static std::uint8_t const GROUP_SESSION_ID_TAG = 052; + +size_t _olm_encode_group_message_length( + size_t group_session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + size_t mac_length +) { + size_t length = VERSION_LENGTH; + length += 1 + varstring_length(group_session_id_length); + length += 1 + varint_length(chain_index); + length += 1 + varstring_length(ciphertext_length); + length += mac_length; + return length; +} + + +void _olm_encode_group_message( + uint8_t version, + const uint8_t *session_id, + size_t session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + uint8_t *output, + uint8_t **ciphertext_ptr +) { + std::uint8_t * pos = output; + std::uint8_t * session_id_pos; + + *(pos++) = version; + pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length); + std::memcpy(session_id_pos, session_id, session_id_length); + pos = encode(pos, COUNTER_TAG, chain_index); + pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); +} diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c new file mode 100644 index 0000000..a23f684 --- /dev/null +++ b/src/outbound_group_session.c @@ -0,0 +1,183 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/outbound_group_session.h" + +#include +#include + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/message.h" + +#define OLM_PROTOCOL_VERSION 3 +#define SESSION_ID_RANDOM_BYTES 4 +#define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) + +struct OlmOutboundGroupSession { + /** the Megolm ratchet providing the encryption keys */ + Megolm ratchet; + + /** unique identifier for this session */ + uint8_t session_id[GROUP_SESSION_ID_LENGTH]; + + enum OlmErrorCode last_error; +}; + + +size_t olm_outbound_group_session_size() { + return sizeof(OlmOutboundGroupSession); +} + +OlmOutboundGroupSession * olm_outbound_group_session( + void *memory +) { + OlmOutboundGroupSession *session = memory; + olm_clear_outbound_group_session(session); + return session; +} + +const char *olm_outbound_group_session_last_error( + const OlmOutboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +size_t olm_clear_outbound_group_session( + OlmOutboundGroupSession *session +) { + memset(session, 0, sizeof(OlmOutboundGroupSession)); + return sizeof(OlmOutboundGroupSession); +} + +size_t olm_init_outbound_group_session_random_length( + const OlmOutboundGroupSession *session +) { + /* we need data to initialize the megolm ratchet, plus some more for the + * session id. + */ + return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES; +} + +size_t olm_init_outbound_group_session( + OlmOutboundGroupSession *session, + uint8_t const * random, size_t random_length +) { + if (random_length < olm_init_outbound_group_session_random_length(session)) { + /* Insufficient random data for new session */ + session->last_error = OLM_NOT_ENOUGH_RANDOM; + return (size_t)-1; + } + + megolm_init(&(session->ratchet), random, 0); + random += MEGOLM_RATCHET_LENGTH; + + /* initialise the session id. This just has to be unique. We use the + * current time plus some random data. + */ + gettimeofday((struct timeval *)(session->session_id), NULL); + memcpy((session->session_id) + sizeof(struct timeval), + random, SESSION_ID_RANDOM_BYTES); + + return 0; +} + +static size_t raw_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length) +{ + size_t ciphertext_length, mac_length; + const struct _olm_cipher *cipher = megolm_cipher(); + + ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, plaintext_length + ); + + mac_length = cipher->ops->mac_length(cipher); + + return _olm_encode_group_message_length( + GROUP_SESSION_ID_LENGTH, session->ratchet.counter, + ciphertext_length, mac_length); +} + +size_t olm_group_encrypt_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length +) { + size_t message_length = raw_message_length(session, plaintext_length); + return _olm_encode_base64_length(message_length); +} + + +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t max_message_length +) { + size_t ciphertext_length; + size_t rawmsglen; + size_t result; + uint8_t *ciphertext_ptr, *message_pos; + const struct _olm_cipher *cipher = megolm_cipher(); + + rawmsglen = raw_message_length(session, plaintext_length); + + if (max_message_length < _olm_encode_base64_length(rawmsglen)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, + plaintext_length + ); + + /* we construct the message at the end of the buffer, so that + * we have room to base64-encode it once we're done. + */ + message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen; + + /* first we build the message structure, then we encrypt + * the plaintext into it. + */ + _olm_encode_group_message( + OLM_PROTOCOL_VERSION, + session->session_id, GROUP_SESSION_ID_LENGTH, + session->ratchet.counter, + ciphertext_length, + message_pos, + &ciphertext_ptr); + + result = cipher->ops->encrypt( + cipher, + megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, + plaintext, plaintext_length, + ciphertext_ptr, ciphertext_length, + message_pos, rawmsglen + ); + + if (result == (size_t)-1) { + return result; + } + + megolm_advance(&(session->ratchet)); + + return _olm_encode_base64( + message_pos, rawmsglen, + message + ); +} diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp new file mode 100644 index 0000000..9081293 --- /dev/null +++ b/tests/test_group_session.cpp @@ -0,0 +1,56 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/outbound_group_session.h" +#include "unittest.hh" + + +int main() { + +uint8_t random_bytes[] = + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF"; + +{ + TestCase test_case("Group message send/receive"); + + size_t size = olm_outbound_group_session_size(); + void *memory = alloca(size); + OlmOutboundGroupSession *session = olm_outbound_group_session(memory); + + assert_equals((size_t)132, + olm_init_outbound_group_session_random_length(session)); + + size_t res = olm_init_outbound_group_session( + session, random_bytes, sizeof(random_bytes)); + assert_equals((size_t)0, res); + + uint8_t plaintext[] = "Message"; + size_t plaintext_length = sizeof(plaintext) - 1; + + size_t msglen = olm_group_encrypt_message_length( + session, plaintext_length); + + uint8_t *msg = (uint8_t *)alloca(msglen); + res = olm_group_encrypt(session, plaintext, plaintext_length, + msg, msglen); + assert_equals(msglen, res); + + // TODO: decode the message +} + +} diff --git a/tests/test_message.cpp b/tests/test_message.cpp index ff14649..e2385ea 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -1,4 +1,4 @@ -/* Copyright 2015 OpenMarket Ltd +/* Copyright 2015-2016 OpenMarket Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,4 +62,39 @@ assert_equals(message2, output, 35); } /* Message encode test */ + +{ /* group message encode test */ + + TestCase test_case("Group message encode test"); + + const uint8_t session_id[] = "sessionid"; + size_t session_id_len = 9; + + size_t length = _olm_encode_group_message_length( + session_id_len, 200, 10, 8); + size_t expected_length = 1 + (2+session_id_len) + (1+2) + (2+10) + 8; + assert_equals(expected_length, length); + + uint8_t output[50]; + uint8_t *ciphertext_ptr; + + _olm_encode_group_message( + 3, + session_id, session_id_len, + 200, // counter + 10, // ciphertext length + output, + &ciphertext_ptr + ); + + uint8_t expected[] = + "\x03" + "\x2A\x09sessionid" + "\x10\xc8\x01" + "\x22\x0a"; + + assert_equals(expected, output, sizeof(expected)-1); + assert_equals(output+sizeof(expected)-1, ciphertext_ptr); +} /* group message encode test */ + } -- cgit v1.2.3-70-g09d2 From c058554132a0f97e8e8ae3a402605220f8fdaed4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 18:53:00 +0100 Subject: Implement pickling/unpickling for outbound group sessions --- include/olm/megolm.h | 15 +++++++ include/olm/outbound_group_session.h | 36 +++++++++++++++++ src/megolm.c | 25 +++++++++++- src/outbound_group_session.c | 77 ++++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 46 ++++++++++++++++++--- 5 files changed, 191 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 5cae353..831c6fb 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -61,6 +61,21 @@ const struct _olm_cipher *megolm_cipher(); */ void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter); +/** Returns the number of bytes needed to store a megolm */ +size_t megolm_pickle_length(const Megolm *megolm); + +/** + * Pickle the megolm. Returns a pointer to the next free space in the buffer. + */ +uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos); + +/** + * Unpickle the megolm. Returns a pointer to the next item in the buffer. + */ +const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos, + const uint8_t *end); + + /** advance the ratchet by one step */ void megolm_advance(Megolm *megolm); diff --git a/include/olm/outbound_group_session.h b/include/olm/outbound_group_session.h index 6c02370..27991ac 100644 --- a/include/olm/outbound_group_session.h +++ b/include/olm/outbound_group_session.h @@ -48,6 +48,42 @@ size_t olm_clear_outbound_group_session( OlmOutboundGroupSession *session ); +/** Returns the number of bytes needed to store an outbound group session */ +size_t olm_pickle_outbound_group_session_length( + const OlmOutboundGroupSession *session +); + +/** + * Stores a group session as a base64 string. Encrypts the session using the + * supplied key. Returns the length of the session on success. + * + * Returns olm_error() on failure. If the pickle output buffer + * is smaller than olm_pickle_outbound_group_session_length() then + * olm_outbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL" + */ +size_t olm_pickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + +/** + * Loads a group session from a pickled base64 string. Decrypts the session + * using the supplied key. + * + * Returns olm_error() on failure. If the key doesn't match the one used to + * encrypt the account then olm_outbound_group_session_last_error() will be + * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then + * olm_outbound_group_session_last_error() will be "INVALID_BASE64". The input + * pickled buffer is destroyed + */ +size_t olm_unpickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + + /** The number of random bytes needed to create an outbound group session */ size_t olm_init_outbound_group_session_random_length( const OlmOutboundGroupSession *session diff --git a/src/megolm.c b/src/megolm.c index 58fe725..7567894 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -20,6 +20,7 @@ #include "olm/cipher.h" #include "olm/crypto.h" +#include "olm/pickle.h" const struct _olm_cipher *megolm_cipher() { static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; @@ -59,12 +60,32 @@ static void rehash_part( -void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) -{ +void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) { megolm->counter = counter; memcpy(megolm->data, random_data, MEGOLM_RATCHET_LENGTH); } +size_t megolm_pickle_length(const Megolm *megolm) { + size_t length = 0; + length += _olm_pickle_bytes_length(megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + length += _olm_pickle_uint32_length(megolm->counter); + return length; + +} + +uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos) { + pos = _olm_pickle_bytes(pos, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + pos = _olm_pickle_uint32(pos, megolm->counter); + return pos; +} + +const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos, + const uint8_t *end) { + pos = _olm_unpickle_bytes(pos, end, (uint8_t *)(megolm->data), + MEGOLM_RATCHET_LENGTH); + pos = _olm_unpickle_uint32(pos, end, &megolm->counter); + return pos; +} /* simplistic implementation for a single step */ void megolm_advance(Megolm *megolm) { diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index a23f684..8dc1cd1 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -23,10 +23,13 @@ #include "olm/error.h" #include "olm/megolm.h" #include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" #define OLM_PROTOCOL_VERSION 3 #define SESSION_ID_RANDOM_BYTES 4 #define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) +#define PICKLE_VERSION 1 struct OlmOutboundGroupSession { /** the Megolm ratchet providing the encryption keys */ @@ -64,6 +67,80 @@ size_t olm_clear_outbound_group_session( return sizeof(OlmOutboundGroupSession); } +static size_t raw_pickle_length( + const OlmOutboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&(session->ratchet)); + length += _olm_pickle_bytes_length(session->session_id, + GROUP_SESSION_ID_LENGTH); + return length; +} + +size_t olm_pickle_outbound_group_session_length( + const OlmOutboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&(session->ratchet), pos); + pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + pos = megolm_unpickle(&(session->ratchet), pos, end); + pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH); + + if (end != pos) { + /* We had the wrong number of bytes in the input. */ + session->last_error = OLM_CORRUPTED_PICKLE; + return (size_t)-1; + } + + return pickled_length; +} + + size_t olm_init_outbound_group_session_random_length( const OlmOutboundGroupSession *session ) { diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 9081293..b9fe1ef 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -18,16 +18,50 @@ int main() { -uint8_t random_bytes[] = - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF"; +{ + + TestCase test_case("Pickle outbound group"); + + size_t size = olm_outbound_group_session_size(); + void *memory = alloca(size); + OlmOutboundGroupSession *session = olm_outbound_group_session(memory); + + size_t pickle_length = olm_pickle_outbound_group_session_length(session); + uint8_t pickle1[pickle_length]; + olm_pickle_outbound_group_session(session, + "secret_key", 10, + pickle1, pickle_length); + uint8_t pickle2[pickle_length]; + memcpy(pickle2, pickle1, pickle_length); + + uint8_t buffer2[size]; + OlmOutboundGroupSession *session2 = olm_outbound_group_session(buffer2); + size_t res = olm_unpickle_outbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + assert_not_equals((size_t)-1, res); + assert_equals(pickle_length, + olm_pickle_outbound_group_session_length(session2)); + olm_pickle_outbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + + assert_equals(pickle1, pickle2, pickle_length); +} + { TestCase test_case("Group message send/receive"); + uint8_t random_bytes[] = + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF"; + + + size_t size = olm_outbound_group_session_size(); void *memory = alloca(size); OlmOutboundGroupSession *session = olm_outbound_group_session(memory); -- cgit v1.2.3-70-g09d2 From 39ad75314b9e28053f568ed6a4109f5d3a9468fe Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 17:23:09 +0100 Subject: Implement decrypting inbound group messages Includes creation of inbound sessions, etc --- include/olm/error.h | 3 + include/olm/inbound_group_session.h | 153 +++++++++++++++++++++++++++ include/olm/message.h | 24 +++++ include/olm/olm.h | 1 + src/inbound_group_session.c | 199 ++++++++++++++++++++++++++++++++++++ src/message.cpp | 42 ++++++++ tests/test_group_session.cpp | 42 ++++++-- tests/test_message.cpp | 22 ++++ 8 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 include/olm/inbound_group_session.h create mode 100644 src/inbound_group_session.c (limited to 'tests') diff --git a/include/olm/error.h b/include/olm/error.h index 87e019a..3f74992 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,6 +32,9 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ + OLM_BAD_RATCHET_KEY = 11, + OLM_BAD_CHAIN_INDEX = 12, + /* remember to update the list of string constants in error.c when updating * this list. */ }; diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h new file mode 100644 index 0000000..4cf4ac4 --- /dev/null +++ b/include/olm/inbound_group_session.h @@ -0,0 +1,153 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OLM_INBOUND_GROUP_SESSION_H_ +#define OLM_INBOUND_GROUP_SESSION_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct OlmInboundGroupSession OlmInboundGroupSession; + +/** get the size of an inbound group session, in bytes. */ +size_t olm_inbound_group_session_size(); + +/** + * Initialise an inbound group session object using the supplied memory + * The supplied memory should be at least olm_inbound_group_session_size() + * bytes. + */ +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +); + +/** + * A null terminated string describing the most recent error to happen to a + * group session */ +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +); + +/** Clears the memory used to back this group session */ +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +); + +/** Returns the number of bytes needed to store an inbound group session */ +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +); + +/** + * Stores a group session as a base64 string. Encrypts the session using the + * supplied key. Returns the length of the session on success. + * + * Returns olm_error() on failure. If the pickle output buffer + * is smaller than olm_pickle_inbound_group_session_length() then + * olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL" + */ +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + +/** + * Loads a group session from a pickled base64 string. Decrypts the session + * using the supplied key. + * + * Returns olm_error() on failure. If the key doesn't match the one used to + * encrypt the account then olm_inbound_group_session_last_error() will be + * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then + * olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input + * pickled buffer is destroyed + */ +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + + +/** + * Start a new inbound group session, based on the parameters supplied. + * + * Returns olm_error() on failure. On failure last_error will be set with an + * error code. The last_error will be: + * + * * OLM_INVALID_BASE64 if the session_key is not valid base64 + * * OLM_BAD_RATCHET_KEY if the session_key is invalid + */ +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + + /* base64-encoded key */ + uint8_t const * session_key, size_t session_key_length +); + +/** + * Get an upper bound on the number of bytes of plain-text the decrypt method + * will write for a given input message length. The actual size could be + * different due to padding. + * + * The input message buffer is destroyed. + * + * Returns olm_error() on failure. + */ +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +); + +/** + * Decrypt a message. + * + * The input message buffer is destroyed. + * + * Returns the length of the decrypted plain-text, or olm_error() on failure. + * + * On failure last_error will be set with an error code. The last_error will + * be: + * * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small + * * OLM_INVALID_BASE64 if the message is not valid base-64 + * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported + * version of the protocol + * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the + * message's index (ie, it was sent before the ratchet key was shared with + * us) + */ +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + + /* input; note that it will be overwritten with the base64-decoded + message. */ + uint8_t * message, size_t message_length, + + /* output */ + uint8_t * plaintext, size_t max_plaintext_length +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_INBOUND_GROUP_SESSION_H_ */ diff --git a/include/olm/message.h b/include/olm/message.h index 05fb56c..bd7aec3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -65,6 +65,30 @@ void _olm_encode_group_message( ); +struct _OlmDecodeGroupMessageResults { + uint8_t version; + const uint8_t *session_id; + size_t session_id_length; + uint32_t chain_index; + int has_chain_index; + const uint8_t *ciphertext; + size_t ciphertext_length; +}; + + +/** + * Reads the message headers from the input buffer. + */ +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + + /* output structure: updated with results */ + struct _OlmDecodeGroupMessageResults *results +); + + + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/olm/olm.h b/include/olm/olm.h index 00e1f63..dbaf71e 100644 --- a/include/olm/olm.h +++ b/include/olm/olm.h @@ -19,6 +19,7 @@ #include #include +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #ifdef __cplusplus diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c new file mode 100644 index 0000000..4796414 --- /dev/null +++ b/src/inbound_group_session.c @@ -0,0 +1,199 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/inbound_group_session.h" + +#include + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/message.h" + +#define OLM_PROTOCOL_VERSION 3 + +struct OlmInboundGroupSession { + /** our earliest known ratchet value */ + Megolm initial_ratchet; + + /** The most recent ratchet value */ + Megolm latest_ratchet; + + enum OlmErrorCode last_error; +}; + +size_t olm_inbound_group_session_size() { + return sizeof(OlmInboundGroupSession); +} + +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +) { + OlmInboundGroupSession *session = memory; + olm_clear_inbound_group_session(session); + return session; +} + +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +) { + memset(session, 0, sizeof(OlmInboundGroupSession)); + return sizeof(OlmInboundGroupSession); +} + +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + const uint8_t * session_key, size_t session_key_length +) { + uint8_t key_buf[MEGOLM_RATCHET_LENGTH]; + size_t raw_length = _olm_decode_base64_length(session_key_length); + + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + if (raw_length != MEGOLM_RATCHET_LENGTH) { + session->last_error = OLM_BAD_RATCHET_KEY; + return (size_t)-1; + } + + _olm_decode_base64(session_key, session_key_length, key_buf); + megolm_init(&session->initial_ratchet, key_buf, message_index); + megolm_init(&session->latest_ratchet, key_buf, message_index); + memset(key_buf, 0, MEGOLM_RATCHET_LENGTH); + + return 0; +} + +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + size_t r; + const struct _olm_cipher *cipher = megolm_cipher(); + struct _OlmDecodeGroupMessageResults decoded_results; + + r = _olm_decode_base64(message, message_length, message); + if (r == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return r; + } + + _olm_decode_group_message( + message, message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.ciphertext) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + return cipher->ops->decrypt_max_plaintext_length( + cipher, decoded_results.ciphertext_length); +} + + +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + struct _OlmDecodeGroupMessageResults decoded_results; + const struct _olm_cipher *cipher = megolm_cipher(); + size_t max_length, raw_message_length, r; + Megolm *megolm; + Megolm tmp_megolm; + + raw_message_length = _olm_decode_base64(message, message_length, message); + if (raw_message_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + _olm_decode_group_message( + message, raw_message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.has_chain_index || !decoded_results.session_id + || !decoded_results.ciphertext + ) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + max_length = cipher->ops->decrypt_max_plaintext_length( + cipher, + decoded_results.ciphertext_length + ); + if (max_plaintext_length < max_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* pick a megolm instance to use. If we're at or beyond the latest ratchet + * value, use that */ + if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + megolm = &session->latest_ratchet; + } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + /* the counter is before our intial ratchet - we can't decode this. */ + session->last_error = OLM_BAD_CHAIN_INDEX; + return (size_t)-1; + } else { + /* otherwise, start from the initial megolm. Take a copy so that we + * don't overwrite the initial megolm */ + tmp_megolm = session->initial_ratchet; + megolm = &tmp_megolm; + } + + megolm_advance_to(megolm, decoded_results.chain_index); + + /* now try checking the mac, and decrypting */ + r = cipher->ops->decrypt( + cipher, + megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, + message, raw_message_length, + decoded_results.ciphertext, decoded_results.ciphertext_length, + plaintext, max_plaintext_length + ); + + memset(&tmp_megolm, 0, sizeof(tmp_megolm)); + if (r == (size_t)-1) { + session->last_error = OLM_BAD_MESSAGE_MAC; + return r; + } + + return r; +} diff --git a/src/message.cpp b/src/message.cpp index df0c7bb..ec44262 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -363,3 +363,45 @@ void _olm_encode_group_message( pos = encode(pos, COUNTER_TAG, chain_index); pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } + +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + struct _OlmDecodeGroupMessageResults *results +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length - mac_length; + std::uint8_t const * unknown = nullptr; + + results->session_id = nullptr; + results->session_id_length = 0; + bool has_chain_index = false; + results->chain_index = 0; + results->ciphertext = nullptr; + results->ciphertext_length = 0; + + if (pos == end) return; + if (input_length < mac_length) return; + results->version = *(pos++); + + while (pos != end) { + pos = decode( + pos, end, GROUP_SESSION_ID_TAG, + results->session_id, results->session_id_length + ); + pos = decode( + pos, end, COUNTER_TAG, + results->chain_index, has_chain_index + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + results->ciphertext, results->ciphertext_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + unknown = pos; + } + + results->has_chain_index = (int)has_chain_index; +} diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index b9fe1ef..5bbdc9d 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #include "unittest.hh" @@ -19,11 +20,10 @@ int main() { { - TestCase test_case("Pickle outbound group"); size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); size_t pickle_length = olm_pickle_outbound_group_session_length(session); @@ -61,9 +61,9 @@ int main() { "0123456789ABDEF0123456789ABCDEF"; - + /* build the outbound session */ size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); assert_equals((size_t)132, @@ -73,18 +73,48 @@ int main() { session, random_bytes, sizeof(random_bytes)); assert_equals((size_t)0, res); + assert_equals(0U, olm_outbound_group_session_message_index(session)); + size_t session_key_len = olm_outbound_group_session_key_length(session); + uint8_t session_key[session_key_len]; + olm_outbound_group_session_key(session, session_key, session_key_len); + + + /* encode the message */ uint8_t plaintext[] = "Message"; size_t plaintext_length = sizeof(plaintext) - 1; size_t msglen = olm_group_encrypt_message_length( session, plaintext_length); - uint8_t *msg = (uint8_t *)alloca(msglen); + uint8_t msg[msglen]; res = olm_group_encrypt(session, plaintext, plaintext_length, msg, msglen); assert_equals(msglen, res); + assert_equals(1U, olm_outbound_group_session_message_index(session)); + + + /* build the inbound session */ + size = olm_inbound_group_session_size(); + uint8_t inbound_session_memory[size]; + OlmInboundGroupSession *inbound_session = + olm_inbound_group_session(inbound_session_memory); + + res = olm_init_inbound_group_session( + inbound_session, 0U, session_key, session_key_len); + assert_equals((size_t)0, res); - // TODO: decode the message + /* decode the message */ + + /* olm_group_decrypt_max_plaintext_length destroys the input so we have to + copy it. */ + uint8_t msgcopy[msglen]; + memcpy(msgcopy, msg, msglen); + size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); + uint8_t plaintext_buf[size]; + res = olm_group_decrypt(inbound_session, msg, msglen, + plaintext_buf, size); + assert_equals(plaintext_length, res); + assert_equals(plaintext, plaintext_buf, res); } } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index e2385ea..5fec9e0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -97,4 +97,26 @@ assert_equals(message2, output, 35); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); } /* group message encode test */ +{ + TestCase test_case("Group message decode test"); + + struct _OlmDecodeGroupMessageResults results; + std::uint8_t message[] = + "\x03" + "\x2A\x09sessionid" + "\x10\xc8\x01" + "\x22\x0A" "ciphertext" + "hmacsha2"; + + const uint8_t expected_session_id[] = "sessionid"; + + _olm_decode_group_message(message, sizeof(message)-1, 8, &results); + assert_equals(std::uint8_t(3), results.version); + assert_equals(std::size_t(9), results.session_id_length); + assert_equals(expected_session_id, results.session_id, 9); + assert_equals(1, results.has_chain_index); + assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(std::size_t(10), results.ciphertext_length); + assert_equals(ciphertext, results.ciphertext, 10); +} /* group message decode test */ } -- cgit v1.2.3-70-g09d2 From a073d12d8367d27db97751d46b766e8480fd39e4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 18:03:59 +0100 Subject: Support for pickling inbound group sessions --- src/inbound_group_session.c | 76 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 33 ++++++++++++++++++- 2 files changed, 108 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 4796414..34908a9 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -22,8 +22,12 @@ #include "olm/error.h" #include "olm/megolm.h" #include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" + #define OLM_PROTOCOL_VERSION 3 +#define PICKLE_VERSION 1 struct OlmInboundGroupSession { /** our earliest known ratchet value */ @@ -86,6 +90,78 @@ size_t olm_init_inbound_group_session( return 0; } +static size_t raw_pickle_length( + const OlmInboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&session->initial_ratchet); + length += megolm_pickle_length(&session->latest_ratchet); + return length; +} + +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&session->initial_ratchet, pos); + pos = megolm_pickle(&session->latest_ratchet, pos); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + pos = megolm_unpickle(&session->initial_ratchet, pos, end); + pos = megolm_unpickle(&session->latest_ratchet, pos, end); + + if (end != pos) { + /* We had the wrong number of bytes in the input. */ + session->last_error = OLM_CORRUPTED_PICKLE; + return (size_t)-1; + } + + return pickled_length; +} + size_t olm_group_decrypt_max_plaintext_length( OlmInboundGroupSession *session, uint8_t * message, size_t message_length diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 5bbdc9d..4a82154 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -20,7 +20,7 @@ int main() { { - TestCase test_case("Pickle outbound group"); + TestCase test_case("Pickle outbound group session"); size_t size = olm_outbound_group_session_size(); uint8_t memory[size]; @@ -50,6 +50,37 @@ int main() { } +{ + TestCase test_case("Pickle inbound group session"); + + size_t size = olm_inbound_group_session_size(); + uint8_t memory[size]; + OlmInboundGroupSession *session = olm_inbound_group_session(memory); + + size_t pickle_length = olm_pickle_inbound_group_session_length(session); + uint8_t pickle1[pickle_length]; + olm_pickle_inbound_group_session(session, + "secret_key", 10, + pickle1, pickle_length); + uint8_t pickle2[pickle_length]; + memcpy(pickle2, pickle1, pickle_length); + + uint8_t buffer2[size]; + OlmInboundGroupSession *session2 = olm_inbound_group_session(buffer2); + size_t res = olm_unpickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + assert_not_equals((size_t)-1, res); + assert_equals(pickle_length, + olm_pickle_inbound_group_session_length(session2)); + olm_pickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + + assert_equals(pickle1, pickle2, pickle_length); +} + + { TestCase test_case("Group message send/receive"); -- cgit v1.2.3-70-g09d2 From fc4756ddf17f536912a89a4ffcf90a309c236ced Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 19 May 2016 07:53:07 +0100 Subject: Fix up some names, and protobuf tags Make names (of session_key and message_index) more consistent. Use our own protobuf tags rather than trying to piggyback on the one-to-one structure. --- include/olm/error.h | 8 ++++++-- include/olm/inbound_group_session.h | 8 ++++---- include/olm/message.h | 8 ++++---- src/error.c | 2 ++ src/inbound_group_session.c | 12 ++++++------ src/message.cpp | 26 ++++++++++++++------------ tests/test_message.cpp | 16 ++++++++-------- 7 files changed, 44 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/include/olm/error.h b/include/olm/error.h index 3f74992..98d2cf5 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,8 +32,12 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ - OLM_BAD_RATCHET_KEY = 11, - OLM_BAD_CHAIN_INDEX = 12, + OLM_BAD_SESSION_KEY = 11, /*!< Attempt to initialise an inbound group + session from an invalid session key */ + OLM_UNKNOWN_MESSAGE_INDEX = 12, /*!< Attempt to decode a message whose + * index is earlier than our earliest + * known session key. + */ /* remember to update the list of string constants in error.c when updating * this list. */ diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h index 4cf4ac4..e24f377 100644 --- a/include/olm/inbound_group_session.h +++ b/include/olm/inbound_group_session.h @@ -91,7 +91,7 @@ size_t olm_unpickle_inbound_group_session( * error code. The last_error will be: * * * OLM_INVALID_BASE64 if the session_key is not valid base64 - * * OLM_BAD_RATCHET_KEY if the session_key is invalid + * * OLM_BAD_SESSION_KEY if the session_key is invalid */ size_t olm_init_inbound_group_session( OlmInboundGroupSession *session, @@ -129,9 +129,9 @@ size_t olm_group_decrypt_max_plaintext_length( * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported * version of the protocol * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded - * * OLM_BAD_MESSAGE_MAC if the message could not be verified - * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the - * message's index (ie, it was sent before the ratchet key was shared with + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_UNKNOWN_MESSAGE_INDEX if we do not have a session key corresponding to the + * message's index (ie, it was sent before the session key was shared with * us) */ size_t olm_group_decrypt( diff --git a/include/olm/message.h b/include/olm/message.h index bd7aec3..cff15f3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -47,7 +47,7 @@ size_t _olm_encode_group_message_length( * version: version number of the olm protocol * session_id: group session identifier * session_id_length: length of session_id - * chain_index: message index + * message_index: message index * ciphertext_length: length of the ciphertext * output: where to write the output. Should be at least * olm_encode_group_message_length() bytes long. @@ -58,7 +58,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -69,8 +69,8 @@ struct _OlmDecodeGroupMessageResults { uint8_t version; const uint8_t *session_id; size_t session_id_length; - uint32_t chain_index; - int has_chain_index; + uint32_t message_index; + int has_message_index; const uint8_t *ciphertext; size_t ciphertext_length; }; diff --git a/src/error.c b/src/error.c index 0690856..bd8a39d 100644 --- a/src/error.c +++ b/src/error.c @@ -27,6 +27,8 @@ static const char * ERRORS[] = { "BAD_ACCOUNT_KEY", "UNKNOWN_PICKLE_VERSION", "CORRUPTED_PICKLE", + "BAD_SESSION_KEY", + "UNKNOWN_MESSAGE_INDEX", }; const char * _olm_error_to_string(enum OlmErrorCode error) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 34908a9..cc6ba5e 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -78,7 +78,7 @@ size_t olm_init_inbound_group_session( } if (raw_length != MEGOLM_RATCHET_LENGTH) { - session->last_error = OLM_BAD_RATCHET_KEY; + session->last_error = OLM_BAD_SESSION_KEY; return (size_t)-1; } @@ -223,7 +223,7 @@ size_t olm_group_decrypt( return (size_t)-1; } - if (!decoded_results.has_chain_index || !decoded_results.session_id + if (!decoded_results.has_message_index || !decoded_results.session_id || !decoded_results.ciphertext ) { session->last_error = OLM_BAD_MESSAGE_FORMAT; @@ -241,11 +241,11 @@ size_t olm_group_decrypt( /* pick a megolm instance to use. If we're at or beyond the latest ratchet * value, use that */ - if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + if ((int32_t)(decoded_results.message_index - session->latest_ratchet.counter) >= 0) { megolm = &session->latest_ratchet; - } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + } else if ((int32_t)(decoded_results.message_index - session->initial_ratchet.counter) < 0) { /* the counter is before our intial ratchet - we can't decode this. */ - session->last_error = OLM_BAD_CHAIN_INDEX; + session->last_error = OLM_UNKNOWN_MESSAGE_INDEX; return (size_t)-1; } else { /* otherwise, start from the initial megolm. Take a copy so that we @@ -254,7 +254,7 @@ size_t olm_group_decrypt( megolm = &tmp_megolm; } - megolm_advance_to(megolm, decoded_results.chain_index); + megolm_advance_to(megolm, decoded_results.message_index); /* now try checking the mac, and decrypting */ r = cipher->ops->decrypt( diff --git a/src/message.cpp b/src/message.cpp index ec44262..ab4300e 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -328,17 +328,19 @@ void olm::decode_one_time_key_message( -static std::uint8_t const GROUP_SESSION_ID_TAG = 052; +static const std::uint8_t GROUP_SESSION_ID_TAG = 012; +static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 020; +static const std::uint8_t GROUP_CIPHERTEXT_TAG = 032; size_t _olm_encode_group_message_length( size_t group_session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, size_t mac_length ) { size_t length = VERSION_LENGTH; length += 1 + varstring_length(group_session_id_length); - length += 1 + varint_length(chain_index); + length += 1 + varint_length(message_index); length += 1 + varstring_length(ciphertext_length); length += mac_length; return length; @@ -349,7 +351,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -360,8 +362,8 @@ void _olm_encode_group_message( *(pos++) = version; pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length); std::memcpy(session_id_pos, session_id, session_id_length); - pos = encode(pos, COUNTER_TAG, chain_index); - pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); + pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index); + pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } void _olm_decode_group_message( @@ -375,8 +377,8 @@ void _olm_decode_group_message( results->session_id = nullptr; results->session_id_length = 0; - bool has_chain_index = false; - results->chain_index = 0; + bool has_message_index = false; + results->message_index = 0; results->ciphertext = nullptr; results->ciphertext_length = 0; @@ -390,11 +392,11 @@ void _olm_decode_group_message( results->session_id, results->session_id_length ); pos = decode( - pos, end, COUNTER_TAG, - results->chain_index, has_chain_index + pos, end, GROUP_MESSAGE_INDEX_TAG, + results->message_index, has_message_index ); pos = decode( - pos, end, CIPHERTEXT_TAG, + pos, end, GROUP_CIPHERTEXT_TAG, results->ciphertext, results->ciphertext_length ); if (unknown == pos) { @@ -403,5 +405,5 @@ void _olm_decode_group_message( unknown = pos; } - results->has_chain_index = (int)has_chain_index; + results->has_message_index = (int)has_message_index; } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index 5fec9e0..30c10a0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -89,9 +89,9 @@ assert_equals(message2, output, 35); uint8_t expected[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0a"; + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A"; assert_equals(expected, output, sizeof(expected)-1); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); @@ -103,9 +103,9 @@ assert_equals(message2, output, 35); struct _OlmDecodeGroupMessageResults results; std::uint8_t message[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0A" "ciphertext" + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A" "ciphertext" "hmacsha2"; const uint8_t expected_session_id[] = "sessionid"; @@ -114,8 +114,8 @@ assert_equals(message2, output, 35); assert_equals(std::uint8_t(3), results.version); assert_equals(std::size_t(9), results.session_id_length); assert_equals(expected_session_id, results.session_id, 9); - assert_equals(1, results.has_chain_index); - assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(1, results.has_message_index); + assert_equals(std::uint32_t(200), results.message_index); assert_equals(std::size_t(10), results.ciphertext_length); assert_equals(ciphertext, results.ciphertext, 10); } /* group message decode test */ -- cgit v1.2.3-70-g09d2 From 01ea3d4b9a3c6f3e0303c2d421a248715a96af99 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 17:52:35 +0100 Subject: Fix handling of integer wraparound in megolm.c --- src/megolm.c | 8 ++++++-- tests/test_megolm.cpp | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/src/megolm.c b/src/megolm.c index 6d8af08..a969b36 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -108,8 +108,12 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { uint32_t mask = (~(uint32_t)0) << shift; int k; - /* how many times to we need to rehash this part? */ - int steps = (advance_to >> shift) - (megolm->counter >> shift); + /* how many times do we need to rehash this part? + * + * '& 0xff' ensures we handle integer wraparound correctly + */ + unsigned int steps = + ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { continue; diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index 871de36..bf53346 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -82,4 +82,20 @@ std::uint8_t random_bytes[] = assert_equals(expected3, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance wraparound"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x1000000); + assert_equals(0x1000000U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0); + megolm_advance_to(&mr2, 0x2000000); + assert_equals(0x2000000U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + } -- cgit v1.2.3-70-g09d2 From 19a7fb5df5ec3445201ce5fbe475a08faf6319fc Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 25 May 2016 15:00:05 +0100 Subject: Fix an integer wrap around bug and add a couple more tests --- src/megolm.c | 6 +++++- tests/test_megolm.cpp | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/src/megolm.c b/src/megolm.c index a969b36..affd3cb 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -116,7 +116,11 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { - continue; + if (advance_to < megolm->counter) { + steps = 0x100; + } else { + continue; + } } /* for all but the last step, we can just bump R(j) without regard diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index bf53346..3048fa3 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -98,4 +98,37 @@ std::uint8_t random_bytes[] = assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance overflow by one"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0xffffffffUL); + megolm_advance(&mr2); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + +{ + TestCase test_case("Megolm::advance overflow"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0x1UL); + megolm_advance_to(&mr1, 0x80000000UL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0x1UL); + megolm_advance_to(&mr2, 0x0UL); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + } -- cgit v1.2.3-70-g09d2