From 256bce10fc6c811293cd7d9089ba2d69cec0d59b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 May 2016 16:47:49 +0100 Subject: Factor out olm_error_to_string to a separate file I want to be able to use this functionality from elsewhere, so factor it out to its own file. --- include/olm/error.h | 7 +++++++ src/error.c | 39 +++++++++++++++++++++++++++++++++++++++ src/olm.cpp | 38 ++++++-------------------------------- 3 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 src/error.c diff --git a/include/olm/error.h b/include/olm/error.h index a4f373e..460017e 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -31,8 +31,15 @@ enum OlmErrorCode { OLM_BAD_ACCOUNT_KEY = 8, /*!< The supplied account key is invalid */ OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ + + /* remember to update the list of string constants in error.c when updating + * this list. */ + }; +/** get a string representation of the given error code. */ +const char * _olm_error_to_string(enum OlmErrorCode error); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/error.c b/src/error.c new file mode 100644 index 0000000..0690856 --- /dev/null +++ b/src/error.c @@ -0,0 +1,39 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/error.h" + +static const char * ERRORS[] = { + "SUCCESS", + "NOT_ENOUGH_RANDOM", + "OUTPUT_BUFFER_TOO_SMALL", + "BAD_MESSAGE_VERSION", + "BAD_MESSAGE_FORMAT", + "BAD_MESSAGE_MAC", + "BAD_MESSAGE_KEY_ID", + "INVALID_BASE64", + "BAD_ACCOUNT_KEY", + "UNKNOWN_PICKLE_VERSION", + "CORRUPTED_PICKLE", +}; + +const char * _olm_error_to_string(enum OlmErrorCode error) +{ + if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) { + return ERRORS[error]; + } else { + return "UNKNOWN_ERROR"; + } +} diff --git a/src/olm.cpp b/src/olm.cpp index fcd033a..babe7eb 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -164,20 +164,6 @@ std::size_t b64_input( return raw_length; } -static const char * ERRORS[11] { - "SUCCESS", - "NOT_ENOUGH_RANDOM", - "OUTPUT_BUFFER_TOO_SMALL", - "BAD_MESSAGE_VERSION", - "BAD_MESSAGE_FORMAT", - "BAD_MESSAGE_MAC", - "BAD_MESSAGE_KEY_ID", - "INVALID_BASE64", - "BAD_ACCOUNT_KEY", - "UNKNOWN_PICKLE_VERSION", - "CORRUPTED_PICKLE", -}; - } // namespace @@ -192,35 +178,23 @@ size_t olm_error() { const char * olm_account_last_error( OlmAccount * account ) { - unsigned error = unsigned(from_c(account)->last_error); - if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) { - return ERRORS[error]; - } else { - return "UNKNOWN_ERROR"; - } + auto error = from_c(account)->last_error; + return _olm_error_to_string(error); } const char * olm_session_last_error( OlmSession * session ) { - unsigned error = unsigned(from_c(session)->last_error); - if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) { - return ERRORS[error]; - } else { - return "UNKNOWN_ERROR"; - } + auto error = from_c(session)->last_error; + return _olm_error_to_string(error); } const char * olm_utility_last_error( OlmUtility * utility ) { - unsigned error = unsigned(from_c(utility)->last_error); - if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) { - return ERRORS[error]; - } else { - return "UNKNOWN_ERROR"; - } + auto error = from_c(utility)->last_error; + return _olm_error_to_string(error); } size_t olm_account_size() { -- cgit v1.2.3 From 42a300fc62a2d10fc14868ac6135d3da3857469f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 18:48:16 +0100 Subject: Factor out pickle_encoding from olm.cpp We don't need to have all of the top-level pickling functions in olm.cpp; factor out the utilities to support it to pickle_encoding.cpp (and make sure that they have plain-C bindings). --- include/olm/pickle_encoding.h | 76 +++++++++++++++++++++++++++++++++ src/olm.cpp | 97 ++++++------------------------------------- src/pickle_encoding.c | 92 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 84 deletions(-) create mode 100644 include/olm/pickle_encoding.h create mode 100644 src/pickle_encoding.c diff --git a/include/olm/pickle_encoding.h b/include/olm/pickle_encoding.h new file mode 100644 index 0000000..03611df --- /dev/null +++ b/include/olm/pickle_encoding.h @@ -0,0 +1,76 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* functions for encrypting and decrypting pickled representations of objects */ + +#ifndef OLM_PICKLE_ENCODING_H_ +#define OLM_PICKLE_ENCODING_H_ + +#include +#include + +#include "olm/error.h" + +#ifdef __cplusplus +extern "C" { +#endif + + +/** + * Get the number of bytes needed to encode a pickle of the length given + */ +size_t _olm_enc_output_length(size_t raw_length); + +/** + * Get the point in the output buffer that the raw pickle should be written to. + * + * In order that we can use the same buffer for the raw pickle, and the encoded + * pickle, the raw pickle needs to be written at the end of the buffer. (The + * base-64 encoding would otherwise overwrite the end of the input before it + * was encoded.) + */ + uint8_t *_olm_enc_output_pos(uint8_t * output, size_t raw_length); + +/** + * Encrypt and encode the given pickle in-situ. + * + * The raw pickle should have been written to enc_output_pos(pickle, + * raw_length). + * + * Returns the number of bytes in the encoded pickle. + */ +size_t _olm_enc_output( + uint8_t const * key, size_t key_length, + uint8_t *pickle, size_t raw_length +); + +/** + * Decode and decrypt the given pickle in-situ. + * + * Returns the number of bytes in the decoded pickle, or olm_error() on error, + * in which case *last_error will be updated, if last_error is non-NULL. + */ +size_t _olm_enc_input( + uint8_t const * key, size_t key_length, + uint8_t * input, size_t b64_length, + enum OlmErrorCode * last_error +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_PICKLE_ENCODING_H_ */ diff --git a/src/olm.cpp b/src/olm.cpp index babe7eb..0a4a734 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -16,6 +16,7 @@ #include "olm/session.hh" #include "olm/account.hh" #include "olm/cipher.h" +#include "olm/pickle_encoding.h" #include "olm/utility.hh" #include "olm/base64.hh" #include "olm/memory.hh" @@ -57,78 +58,6 @@ static std::uint8_t const * from_c(void const * bytes) { return reinterpret_cast(bytes); } -static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER = - OLM_CIPHER_INIT_AES_SHA_256("Pickle"); - -std::size_t enc_output_length( - size_t raw_length -) { - auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); - std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); - length += cipher->ops->mac_length(cipher); - return olm::encode_base64_length(length); -} - - -std::uint8_t * enc_output_pos( - std::uint8_t * output, - size_t raw_length -) { - auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); - std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); - length += cipher->ops->mac_length(cipher); - return output + olm::encode_base64_length(length) - length; -} - -std::size_t enc_output( - std::uint8_t const * key, std::size_t key_length, - std::uint8_t * output, size_t raw_length -) { - auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); - std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length( - cipher, raw_length - ); - std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher); - std::size_t base64_length = olm::encode_base64_length(length); - std::uint8_t * raw_output = output + base64_length - length; - cipher->ops->encrypt( - cipher, - key, key_length, - raw_output, raw_length, - raw_output, ciphertext_length, - raw_output, length - ); - olm::encode_base64(raw_output, length, output); - return raw_length; -} - -std::size_t enc_input( - std::uint8_t const * key, std::size_t key_length, - std::uint8_t * input, size_t b64_length, - OlmErrorCode & last_error -) { - std::size_t enc_length = olm::decode_base64_length(b64_length); - if (enc_length == std::size_t(-1)) { - last_error = OlmErrorCode::OLM_INVALID_BASE64; - return std::size_t(-1); - } - olm::decode_base64(input, b64_length, input); - auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); - std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher); - std::size_t result = cipher->ops->decrypt( - cipher, - key, key_length, - input, enc_length, - input, raw_length, - input, raw_length - ); - if (result == std::size_t(-1)) { - last_error = OlmErrorCode::OLM_BAD_ACCOUNT_KEY; - } - return result; -} - - std::size_t b64_output_length( size_t raw_length ) { @@ -270,14 +199,14 @@ size_t olm_clear_utility( size_t olm_pickle_account_length( OlmAccount * account ) { - return enc_output_length(pickle_length(*from_c(account))); + return _olm_enc_output_length(pickle_length(*from_c(account))); } size_t olm_pickle_session_length( OlmSession * session ) { - return enc_output_length(pickle_length(*from_c(session))); + return _olm_enc_output_length(pickle_length(*from_c(session))); } @@ -288,12 +217,12 @@ size_t olm_pickle_account( ) { olm::Account & object = *from_c(account); std::size_t raw_length = pickle_length(object); - if (pickled_length < enc_output_length(raw_length)) { + if (pickled_length < _olm_enc_output_length(raw_length)) { object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; return size_t(-1); } - pickle(enc_output_pos(from_c(pickled), raw_length), object); - return enc_output(from_c(key), key_length, from_c(pickled), raw_length); + pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object); + return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length); } @@ -304,12 +233,12 @@ size_t olm_pickle_session( ) { olm::Session & object = *from_c(session); std::size_t raw_length = pickle_length(object); - if (pickled_length < enc_output_length(raw_length)) { + if (pickled_length < _olm_enc_output_length(raw_length)) { object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; return size_t(-1); } - pickle(enc_output_pos(from_c(pickled), raw_length), object); - return enc_output(from_c(key), key_length, from_c(pickled), raw_length); + pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object); + return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length); } @@ -320,8 +249,8 @@ size_t olm_unpickle_account( ) { olm::Account & object = *from_c(account); std::uint8_t * const pos = from_c(pickled); - std::size_t raw_length = enc_input( - from_c(key), key_length, pos, pickled_length, object.last_error + std::size_t raw_length = _olm_enc_input( + from_c(key), key_length, pos, pickled_length, &object.last_error ); if (raw_length == std::size_t(-1)) { return std::size_t(-1); @@ -348,8 +277,8 @@ size_t olm_unpickle_session( ) { olm::Session & object = *from_c(session); std::uint8_t * const pos = from_c(pickled); - std::size_t raw_length = enc_input( - from_c(key), key_length, pos, pickled_length, object.last_error + std::size_t raw_length = _olm_enc_input( + from_c(key), key_length, pos, pickled_length, &object.last_error ); if (raw_length == std::size_t(-1)) { return std::size_t(-1); diff --git a/src/pickle_encoding.c b/src/pickle_encoding.c new file mode 100644 index 0000000..5d5f8d7 --- /dev/null +++ b/src/pickle_encoding.c @@ -0,0 +1,92 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/pickle_encoding.h" + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/olm.h" + +static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("Pickle"); + +size_t _olm_enc_output_length( + size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); + return _olm_encode_base64_length(length); +} + +uint8_t * _olm_enc_output_pos( + uint8_t * output, + size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); + return output + _olm_encode_base64_length(length) - length; +} + +size_t _olm_enc_output( + uint8_t const * key, size_t key_length, + uint8_t * output, size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, raw_length + ); + size_t length = ciphertext_length + cipher->ops->mac_length(cipher); + size_t base64_length = _olm_encode_base64_length(length); + uint8_t * raw_output = output + base64_length - length; + cipher->ops->encrypt( + cipher, + key, key_length, + raw_output, raw_length, + raw_output, ciphertext_length, + raw_output, length + ); + _olm_encode_base64(raw_output, length, output); + return raw_length; +} + + +size_t _olm_enc_input(uint8_t const * key, size_t key_length, + uint8_t * input, size_t b64_length, + enum OlmErrorCode * last_error +) { + size_t enc_length = _olm_decode_base64_length(b64_length); + if (enc_length == (size_t)-1) { + if (last_error) { + *last_error = OLM_INVALID_BASE64; + } + return (size_t)-1; + } + _olm_decode_base64(input, b64_length, input); + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t raw_length = enc_length - cipher->ops->mac_length(cipher); + size_t result = cipher->ops->decrypt( + cipher, + key, key_length, + input, enc_length, + input, raw_length, + input, raw_length + ); + if (result == (size_t)-1 && last_error) { + *last_error = OLM_BAD_ACCOUNT_KEY; + } + return result; +} -- cgit v1.2.3 From 68d3c7bfa9d0d2f8a44edcd2d277c4a516ed6ed5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 27 Apr 2016 17:16:55 +0100 Subject: Implementation of the megolm ratchet --- include/olm/megolm.h | 72 +++++++++++++++++++++++++++ src/megolm.c | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_megolm.cpp | 85 ++++++++++++++++++++++++++++++++ 3 files changed, 289 insertions(+) create mode 100644 include/olm/megolm.h create mode 100644 src/megolm.c create mode 100644 tests/test_megolm.cpp diff --git a/include/olm/megolm.h b/include/olm/megolm.h new file mode 100644 index 0000000..784597e --- /dev/null +++ b/include/olm/megolm.h @@ -0,0 +1,72 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OLM_MEGOLM_H_ +#define OLM_MEGOLM_H_ + +/** + * implementation of the Megolm multi-part ratchet used in group chats. + */ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * number of bytes in each part of the ratchet; this should be the same as + * the length of the hash function used in the HMAC (32 bytes for us, as we + * use HMAC-SHA-256) + */ +#define MEGOLM_RATCHET_PART_LENGTH 32 /* SHA256_OUTPUT_LENGTH */ + +/** + * number of parts in the ratchet; the advance() implementations rely on + * this being 4. + */ +#define MEGOLM_RATCHET_PARTS 4 + +#define MEGOLM_RATCHET_LENGTH (MEGOLM_RATCHET_PARTS * MEGOLM_RATCHET_PART_LENGTH) + +typedef struct Megolm { + uint8_t data[MEGOLM_RATCHET_PARTS][MEGOLM_RATCHET_PART_LENGTH]; + uint32_t counter; +} Megolm; + +/** + * initialize the megolm ratchet. random_data should be at least + * MEGOLM_RATCHET_LENGTH bytes of randomness. + */ +void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter); + +/** advance the ratchet by one step */ +void megolm_advance(Megolm *megolm); + +/** + * get the key data in the ratchet. The returned data is + * MEGOLM_RATCHET_LENGTH bytes long. + */ +#define megolm_get_data(megolm) ((const uint8_t *)((megolm)->data)) + +/** advance the ratchet to a given count */ +void megolm_advance_to(Megolm *megolm, uint32_t advance_to); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_MEGOLM_H_ */ diff --git a/src/megolm.c b/src/megolm.c new file mode 100644 index 0000000..36b0cc2 --- /dev/null +++ b/src/megolm.c @@ -0,0 +1,132 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "olm/megolm.h" + +#include + +#include "olm/crypto.h" + +/* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. + */ +#define HASH_KEY_SEED_LENGTH 1 +static uint8_t HASH_KEY_SEEDS[MEGOLM_RATCHET_PARTS][HASH_KEY_SEED_LENGTH] = { + {0x00}, + {0x01}, + {0x02}, + {0x03} +}; + +static void rehash_part( + uint8_t data[MEGOLM_RATCHET_PARTS][MEGOLM_RATCHET_PART_LENGTH], + int rehash_from_part, int rehash_to_part, + uint32_t old_counter, uint32_t new_counter +) { + _olm_crypto_hmac_sha256( + data[rehash_from_part], + MEGOLM_RATCHET_PART_LENGTH, + HASH_KEY_SEEDS[rehash_to_part], HASH_KEY_SEED_LENGTH, + data[rehash_to_part] + ); +} + + + +void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) +{ + megolm->counter = counter; + memcpy(megolm->data, random_data, MEGOLM_RATCHET_LENGTH); +} + + +/* simplistic implementation for a single step */ +void megolm_advance(Megolm *megolm) { + uint32_t mask = 0x00FFFFFF; + int h = 0; + int i; + + megolm->counter++; + + /* figure out how much we need to rekey */ + while (h < (int)MEGOLM_RATCHET_PARTS) { + if (!(megolm->counter & mask)) + break; + h++; + mask >>= 8; + } + + /* now update R(h)...R(3) based on R(h) */ + for (i = MEGOLM_RATCHET_PARTS-1; i >= h; i--) { + rehash_part(megolm->data, h, i, megolm->counter-1, megolm->counter); + } +} + +void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { + int j; + + /* starting with R0, see if we need to update each part of the hash */ + for (j = 0; j < (int)MEGOLM_RATCHET_PARTS; j++) { + int shift = (MEGOLM_RATCHET_PARTS-j-1) * 8; + uint32_t increment = 1 << shift; + uint32_t next_counter; + + /* how many times to we need to rehash this part? */ + int steps = (advance_to >> shift) - (megolm->counter >> shift); + if (steps == 0) { + continue; + } + + megolm->counter = megolm->counter & ~(increment - 1); + next_counter = megolm->counter + increment; + + /* for all but the last step, we can just bump R(j) without regard + * to R(j+1)...R(3). + */ + while (steps > 1) { + rehash_part(megolm->data, j, j, megolm->counter, next_counter); + megolm->counter = next_counter; + steps --; + next_counter = megolm->counter + increment; + } + + /* on the last step (except for j=3), we need to bump at least R(j+1); + * depending on the target count, we may also need to bump R(j+2) and + * R(j+3). + */ + int k; + switch(j) { + case 0: + if (!(advance_to & 0xFFFF00)) { k = 3; } + else if (!(advance_to & 0xFF00)) { k = 2; } + else { k = 1; } + break; + case 1: + if (!(advance_to & 0xFF00)) { k = 3; } + else { k = 2; } + break; + case 2: + case 3: + k = 3; + break; + } + + while (k >= j) { + rehash_part(megolm->data, j, k, megolm->counter, next_counter); + k--; + } + megolm->counter = next_counter; + } +} diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp new file mode 100644 index 0000000..871de36 --- /dev/null +++ b/tests/test_megolm.cpp @@ -0,0 +1,85 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/megolm.h" +#include "olm/memory.hh" + +#include "unittest.hh" + + +int main() { + +std::uint8_t random_bytes[] = + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF" + "0123456789ABCDEF0123456789ABCDEF"; + +{ + TestCase test_case("Megolm::advance"); + + Megolm mr; + + megolm_init(&mr, random_bytes, 0); + // single-step advance + megolm_advance(&mr); + const std::uint8_t expected1[] = { + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a + }; + assert_equals(1U, mr.counter); + assert_equals(expected1, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + // repeat with complex advance + megolm_init(&mr, random_bytes, 0); + megolm_advance_to(&mr, 1); + assert_equals(1U, mr.counter); + assert_equals(expected1, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + megolm_advance_to(&mr, 0x1000000); + const std::uint8_t expected2[] = { + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32, + 0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f, + 0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd, + 0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, + }; + assert_equals(0x1000000U, mr.counter); + assert_equals(expected2, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); + + megolm_advance_to(&mr, 0x1041506); + const std::uint8_t expected3[] = { + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f, + 0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd, + 0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68, + 0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b, + 0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c, + 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a + }; + assert_equals(0x1041506U, mr.counter); + assert_equals(expected3, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); +} + +} -- cgit v1.2.3 From caaed796ad54de3f8ee1e56123973ae9ace346b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 11:52:06 +0100 Subject: Implementation of an outbound group session --- include/olm/error.h | 1 - include/olm/megolm.h | 8 ++ include/olm/message.h | 72 ++++++++++++++ include/olm/message.hh | 12 +++ include/olm/olm.h | 2 + include/olm/outbound_group_session.h | 90 +++++++++++++++++ src/megolm.c | 14 +++ src/message.cpp | 40 +++++++- src/outbound_group_session.c | 183 +++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 56 +++++++++++ tests/test_message.cpp | 37 ++++++- 11 files changed, 512 insertions(+), 3 deletions(-) create mode 100644 include/olm/message.h create mode 100644 include/olm/outbound_group_session.h create mode 100644 src/outbound_group_session.c create mode 100644 tests/test_group_session.cpp diff --git a/include/olm/error.h b/include/olm/error.h index 460017e..87e019a 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -34,7 +34,6 @@ enum OlmErrorCode { /* remember to update the list of string constants in error.c when updating * this list. */ - }; /** get a string representation of the given error code. */ diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 784597e..5cae353 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -47,6 +47,14 @@ typedef struct Megolm { uint32_t counter; } Megolm; + +/** + * Get the cipher used in megolm-backed conversations + * + * (AES256 + SHA256, with keys based on an HKDF with info of MEGOLM_KEYS) + */ +const struct _olm_cipher *megolm_cipher(); + /** * initialize the megolm ratchet. random_data should be at least * MEGOLM_RATCHET_LENGTH bytes of randomness. diff --git a/include/olm/message.h b/include/olm/message.h new file mode 100644 index 0000000..05fb56c --- /dev/null +++ b/include/olm/message.h @@ -0,0 +1,72 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * functions for encoding and decoding messages in the Olm protocol. + * + * Some of these functions have only C++ bindings, and are declared in + * message.hh; in time, they should probably be converted to plain C and + * declared here. + */ + +#ifndef OLM_MESSAGE_H_ +#define OLM_MESSAGE_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * The length of the buffer needed to hold a group message. + */ +size_t _olm_encode_group_message_length( + size_t group_session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + size_t mac_length +); + +/** + * Writes the message headers into the output buffer. + * + * version: version number of the olm protocol + * session_id: group session identifier + * session_id_length: length of session_id + * chain_index: message index + * ciphertext_length: length of the ciphertext + * output: where to write the output. Should be at least + * olm_encode_group_message_length() bytes long. + * ciphertext_ptr: returns the address that the ciphertext + * should be written to, followed by the MAC. + */ +void _olm_encode_group_message( + uint8_t version, + const uint8_t *session_id, + size_t session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + uint8_t *output, + uint8_t **ciphertext_ptr +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_MESSAGE_H_ */ diff --git a/include/olm/message.hh b/include/olm/message.hh index 5ce0a62..bd912d9 100644 --- a/include/olm/message.hh +++ b/include/olm/message.hh @@ -12,6 +12,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + + +/** + * functions for encoding and decoding messages in the Olm protocol. + * + * Some of these functions have plain-C bindings, and are declared in + * message.h; in time, all of the functions declared here should probably be + * converted to plain C and moved to message.h. + */ + +#include "message.h" + #include #include diff --git a/include/olm/olm.h b/include/olm/olm.h index 8abac49..00e1f63 100644 --- a/include/olm/olm.h +++ b/include/olm/olm.h @@ -19,6 +19,8 @@ #include #include +#include "olm/outbound_group_session.h" + #ifdef __cplusplus extern "C" { #endif diff --git a/include/olm/outbound_group_session.h b/include/olm/outbound_group_session.h new file mode 100644 index 0000000..6c02370 --- /dev/null +++ b/include/olm/outbound_group_session.h @@ -0,0 +1,90 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OLM_OUTBOUND_GROUP_SESSION_H_ +#define OLM_OUTBOUND_GROUP_SESSION_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct OlmOutboundGroupSession OlmOutboundGroupSession; + +/** get the size of an outbound group session, in bytes. */ +size_t olm_outbound_group_session_size(); + +/** + * Initialise an outbound group session object using the supplied memory + * The supplied memory should be at least olm_outbound_group_session_size() + * bytes. + */ +OlmOutboundGroupSession * olm_outbound_group_session( + void *memory +); + +/** + * A null terminated string describing the most recent error to happen to a + * group session */ +const char *olm_outbound_group_session_last_error( + const OlmOutboundGroupSession *session +); + +/** Clears the memory used to back this group session */ +size_t olm_clear_outbound_group_session( + OlmOutboundGroupSession *session +); + +/** The number of random bytes needed to create an outbound group session */ +size_t olm_init_outbound_group_session_random_length( + const OlmOutboundGroupSession *session +); + +/** + * Start a new outbound group session. Returns std::size_t(-1) on failure. On + * failure last_error will be set with an error code. The last_error will be + * NOT_ENOUGH_RANDOM if the number of random bytes was too small. + */ +size_t olm_init_outbound_group_session( + OlmOutboundGroupSession *session, + uint8_t const * random, size_t random_length +); + +/** + * The number of bytes that will be created by encrypting a message + */ +size_t olm_group_encrypt_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length +); + +/** + * Encrypt some plain-text. Returns the length of the encrypted message or + * std::size_t(-1) on failure. On failure last_error will be set with an + * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the output + * buffer is too small. + */ +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t message_length +); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_OUTBOUND_GROUP_SESSION_H_ */ diff --git a/src/megolm.c b/src/megolm.c index 36b0cc2..58fe725 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -18,8 +18,22 @@ #include +#include "olm/cipher.h" #include "olm/crypto.h" +const struct _olm_cipher *megolm_cipher() { + static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; + static struct _olm_cipher *cipher; + static struct _olm_cipher_aes_sha_256 OLM_CIPHER; + if (!cipher) { + cipher = _olm_cipher_aes_sha_256_init( + &OLM_CIPHER, + CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 + ); + } + return cipher; +} + /* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. */ #define HASH_KEY_SEED_LENGTH 1 diff --git a/src/message.cpp b/src/message.cpp index 3be5234..df0c7bb 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -1,4 +1,4 @@ -/* Copyright 2015 OpenMarket Ltd +/* Copyright 2015-2016 OpenMarket Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -325,3 +325,41 @@ void olm::decode_one_time_key_message( unknown = pos; } } + + + +static std::uint8_t const GROUP_SESSION_ID_TAG = 052; + +size_t _olm_encode_group_message_length( + size_t group_session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + size_t mac_length +) { + size_t length = VERSION_LENGTH; + length += 1 + varstring_length(group_session_id_length); + length += 1 + varint_length(chain_index); + length += 1 + varstring_length(ciphertext_length); + length += mac_length; + return length; +} + + +void _olm_encode_group_message( + uint8_t version, + const uint8_t *session_id, + size_t session_id_length, + uint32_t chain_index, + size_t ciphertext_length, + uint8_t *output, + uint8_t **ciphertext_ptr +) { + std::uint8_t * pos = output; + std::uint8_t * session_id_pos; + + *(pos++) = version; + pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length); + std::memcpy(session_id_pos, session_id, session_id_length); + pos = encode(pos, COUNTER_TAG, chain_index); + pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); +} diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c new file mode 100644 index 0000000..a23f684 --- /dev/null +++ b/src/outbound_group_session.c @@ -0,0 +1,183 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/outbound_group_session.h" + +#include +#include + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/message.h" + +#define OLM_PROTOCOL_VERSION 3 +#define SESSION_ID_RANDOM_BYTES 4 +#define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) + +struct OlmOutboundGroupSession { + /** the Megolm ratchet providing the encryption keys */ + Megolm ratchet; + + /** unique identifier for this session */ + uint8_t session_id[GROUP_SESSION_ID_LENGTH]; + + enum OlmErrorCode last_error; +}; + + +size_t olm_outbound_group_session_size() { + return sizeof(OlmOutboundGroupSession); +} + +OlmOutboundGroupSession * olm_outbound_group_session( + void *memory +) { + OlmOutboundGroupSession *session = memory; + olm_clear_outbound_group_session(session); + return session; +} + +const char *olm_outbound_group_session_last_error( + const OlmOutboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +size_t olm_clear_outbound_group_session( + OlmOutboundGroupSession *session +) { + memset(session, 0, sizeof(OlmOutboundGroupSession)); + return sizeof(OlmOutboundGroupSession); +} + +size_t olm_init_outbound_group_session_random_length( + const OlmOutboundGroupSession *session +) { + /* we need data to initialize the megolm ratchet, plus some more for the + * session id. + */ + return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES; +} + +size_t olm_init_outbound_group_session( + OlmOutboundGroupSession *session, + uint8_t const * random, size_t random_length +) { + if (random_length < olm_init_outbound_group_session_random_length(session)) { + /* Insufficient random data for new session */ + session->last_error = OLM_NOT_ENOUGH_RANDOM; + return (size_t)-1; + } + + megolm_init(&(session->ratchet), random, 0); + random += MEGOLM_RATCHET_LENGTH; + + /* initialise the session id. This just has to be unique. We use the + * current time plus some random data. + */ + gettimeofday((struct timeval *)(session->session_id), NULL); + memcpy((session->session_id) + sizeof(struct timeval), + random, SESSION_ID_RANDOM_BYTES); + + return 0; +} + +static size_t raw_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length) +{ + size_t ciphertext_length, mac_length; + const struct _olm_cipher *cipher = megolm_cipher(); + + ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, plaintext_length + ); + + mac_length = cipher->ops->mac_length(cipher); + + return _olm_encode_group_message_length( + GROUP_SESSION_ID_LENGTH, session->ratchet.counter, + ciphertext_length, mac_length); +} + +size_t olm_group_encrypt_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length +) { + size_t message_length = raw_message_length(session, plaintext_length); + return _olm_encode_base64_length(message_length); +} + + +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t max_message_length +) { + size_t ciphertext_length; + size_t rawmsglen; + size_t result; + uint8_t *ciphertext_ptr, *message_pos; + const struct _olm_cipher *cipher = megolm_cipher(); + + rawmsglen = raw_message_length(session, plaintext_length); + + if (max_message_length < _olm_encode_base64_length(rawmsglen)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, + plaintext_length + ); + + /* we construct the message at the end of the buffer, so that + * we have room to base64-encode it once we're done. + */ + message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen; + + /* first we build the message structure, then we encrypt + * the plaintext into it. + */ + _olm_encode_group_message( + OLM_PROTOCOL_VERSION, + session->session_id, GROUP_SESSION_ID_LENGTH, + session->ratchet.counter, + ciphertext_length, + message_pos, + &ciphertext_ptr); + + result = cipher->ops->encrypt( + cipher, + megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, + plaintext, plaintext_length, + ciphertext_ptr, ciphertext_length, + message_pos, rawmsglen + ); + + if (result == (size_t)-1) { + return result; + } + + megolm_advance(&(session->ratchet)); + + return _olm_encode_base64( + message_pos, rawmsglen, + message + ); +} diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp new file mode 100644 index 0000000..9081293 --- /dev/null +++ b/tests/test_group_session.cpp @@ -0,0 +1,56 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/outbound_group_session.h" +#include "unittest.hh" + + +int main() { + +uint8_t random_bytes[] = + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF"; + +{ + TestCase test_case("Group message send/receive"); + + size_t size = olm_outbound_group_session_size(); + void *memory = alloca(size); + OlmOutboundGroupSession *session = olm_outbound_group_session(memory); + + assert_equals((size_t)132, + olm_init_outbound_group_session_random_length(session)); + + size_t res = olm_init_outbound_group_session( + session, random_bytes, sizeof(random_bytes)); + assert_equals((size_t)0, res); + + uint8_t plaintext[] = "Message"; + size_t plaintext_length = sizeof(plaintext) - 1; + + size_t msglen = olm_group_encrypt_message_length( + session, plaintext_length); + + uint8_t *msg = (uint8_t *)alloca(msglen); + res = olm_group_encrypt(session, plaintext, plaintext_length, + msg, msglen); + assert_equals(msglen, res); + + // TODO: decode the message +} + +} diff --git a/tests/test_message.cpp b/tests/test_message.cpp index ff14649..e2385ea 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -1,4 +1,4 @@ -/* Copyright 2015 OpenMarket Ltd +/* Copyright 2015-2016 OpenMarket Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,4 +62,39 @@ assert_equals(message2, output, 35); } /* Message encode test */ + +{ /* group message encode test */ + + TestCase test_case("Group message encode test"); + + const uint8_t session_id[] = "sessionid"; + size_t session_id_len = 9; + + size_t length = _olm_encode_group_message_length( + session_id_len, 200, 10, 8); + size_t expected_length = 1 + (2+session_id_len) + (1+2) + (2+10) + 8; + assert_equals(expected_length, length); + + uint8_t output[50]; + uint8_t *ciphertext_ptr; + + _olm_encode_group_message( + 3, + session_id, session_id_len, + 200, // counter + 10, // ciphertext length + output, + &ciphertext_ptr + ); + + uint8_t expected[] = + "\x03" + "\x2A\x09sessionid" + "\x10\xc8\x01" + "\x22\x0a"; + + assert_equals(expected, output, sizeof(expected)-1); + assert_equals(output+sizeof(expected)-1, ciphertext_ptr); +} /* group message encode test */ + } -- cgit v1.2.3 From c058554132a0f97e8e8ae3a402605220f8fdaed4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 18:53:00 +0100 Subject: Implement pickling/unpickling for outbound group sessions --- include/olm/megolm.h | 15 +++++++ include/olm/outbound_group_session.h | 36 +++++++++++++++++ src/megolm.c | 25 +++++++++++- src/outbound_group_session.c | 77 ++++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 46 ++++++++++++++++++--- 5 files changed, 191 insertions(+), 8 deletions(-) diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 5cae353..831c6fb 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -61,6 +61,21 @@ const struct _olm_cipher *megolm_cipher(); */ void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter); +/** Returns the number of bytes needed to store a megolm */ +size_t megolm_pickle_length(const Megolm *megolm); + +/** + * Pickle the megolm. Returns a pointer to the next free space in the buffer. + */ +uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos); + +/** + * Unpickle the megolm. Returns a pointer to the next item in the buffer. + */ +const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos, + const uint8_t *end); + + /** advance the ratchet by one step */ void megolm_advance(Megolm *megolm); diff --git a/include/olm/outbound_group_session.h b/include/olm/outbound_group_session.h index 6c02370..27991ac 100644 --- a/include/olm/outbound_group_session.h +++ b/include/olm/outbound_group_session.h @@ -48,6 +48,42 @@ size_t olm_clear_outbound_group_session( OlmOutboundGroupSession *session ); +/** Returns the number of bytes needed to store an outbound group session */ +size_t olm_pickle_outbound_group_session_length( + const OlmOutboundGroupSession *session +); + +/** + * Stores a group session as a base64 string. Encrypts the session using the + * supplied key. Returns the length of the session on success. + * + * Returns olm_error() on failure. If the pickle output buffer + * is smaller than olm_pickle_outbound_group_session_length() then + * olm_outbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL" + */ +size_t olm_pickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + +/** + * Loads a group session from a pickled base64 string. Decrypts the session + * using the supplied key. + * + * Returns olm_error() on failure. If the key doesn't match the one used to + * encrypt the account then olm_outbound_group_session_last_error() will be + * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then + * olm_outbound_group_session_last_error() will be "INVALID_BASE64". The input + * pickled buffer is destroyed + */ +size_t olm_unpickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + + /** The number of random bytes needed to create an outbound group session */ size_t olm_init_outbound_group_session_random_length( const OlmOutboundGroupSession *session diff --git a/src/megolm.c b/src/megolm.c index 58fe725..7567894 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -20,6 +20,7 @@ #include "olm/cipher.h" #include "olm/crypto.h" +#include "olm/pickle.h" const struct _olm_cipher *megolm_cipher() { static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; @@ -59,12 +60,32 @@ static void rehash_part( -void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) -{ +void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) { megolm->counter = counter; memcpy(megolm->data, random_data, MEGOLM_RATCHET_LENGTH); } +size_t megolm_pickle_length(const Megolm *megolm) { + size_t length = 0; + length += _olm_pickle_bytes_length(megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + length += _olm_pickle_uint32_length(megolm->counter); + return length; + +} + +uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos) { + pos = _olm_pickle_bytes(pos, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + pos = _olm_pickle_uint32(pos, megolm->counter); + return pos; +} + +const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos, + const uint8_t *end) { + pos = _olm_unpickle_bytes(pos, end, (uint8_t *)(megolm->data), + MEGOLM_RATCHET_LENGTH); + pos = _olm_unpickle_uint32(pos, end, &megolm->counter); + return pos; +} /* simplistic implementation for a single step */ void megolm_advance(Megolm *megolm) { diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index a23f684..8dc1cd1 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -23,10 +23,13 @@ #include "olm/error.h" #include "olm/megolm.h" #include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" #define OLM_PROTOCOL_VERSION 3 #define SESSION_ID_RANDOM_BYTES 4 #define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) +#define PICKLE_VERSION 1 struct OlmOutboundGroupSession { /** the Megolm ratchet providing the encryption keys */ @@ -64,6 +67,80 @@ size_t olm_clear_outbound_group_session( return sizeof(OlmOutboundGroupSession); } +static size_t raw_pickle_length( + const OlmOutboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&(session->ratchet)); + length += _olm_pickle_bytes_length(session->session_id, + GROUP_SESSION_ID_LENGTH); + return length; +} + +size_t olm_pickle_outbound_group_session_length( + const OlmOutboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&(session->ratchet), pos); + pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + pos = megolm_unpickle(&(session->ratchet), pos, end); + pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH); + + if (end != pos) { + /* We had the wrong number of bytes in the input. */ + session->last_error = OLM_CORRUPTED_PICKLE; + return (size_t)-1; + } + + return pickled_length; +} + + size_t olm_init_outbound_group_session_random_length( const OlmOutboundGroupSession *session ) { diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 9081293..b9fe1ef 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -18,16 +18,50 @@ int main() { -uint8_t random_bytes[] = - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF" - "0123456789ABDEF0123456789ABCDEF"; +{ + + TestCase test_case("Pickle outbound group"); + + size_t size = olm_outbound_group_session_size(); + void *memory = alloca(size); + OlmOutboundGroupSession *session = olm_outbound_group_session(memory); + + size_t pickle_length = olm_pickle_outbound_group_session_length(session); + uint8_t pickle1[pickle_length]; + olm_pickle_outbound_group_session(session, + "secret_key", 10, + pickle1, pickle_length); + uint8_t pickle2[pickle_length]; + memcpy(pickle2, pickle1, pickle_length); + + uint8_t buffer2[size]; + OlmOutboundGroupSession *session2 = olm_outbound_group_session(buffer2); + size_t res = olm_unpickle_outbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + assert_not_equals((size_t)-1, res); + assert_equals(pickle_length, + olm_pickle_outbound_group_session_length(session2)); + olm_pickle_outbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + + assert_equals(pickle1, pickle2, pickle_length); +} + { TestCase test_case("Group message send/receive"); + uint8_t random_bytes[] = + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF"; + + + size_t size = olm_outbound_group_session_size(); void *memory = alloca(size); OlmOutboundGroupSession *session = olm_outbound_group_session(memory); -- cgit v1.2.3 From e545ad7eaf55ac8b7dc7d37c046c541e35cef542 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 May 2016 18:55:39 +0100 Subject: Outbound group session support in the python wrappers --- python/.gitignore | 1 + python/olm/__init__.py | 1 + python/olm/__main__.py | 38 ++++++++++++++++- python/olm/outbound_group_session.py | 83 ++++++++++++++++++++++++++++++++++++ python/test_olm.sh | 8 ++++ 5 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 python/olm/outbound_group_session.py diff --git a/python/.gitignore b/python/.gitignore index 4e9d33a..b8ca4f7 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,3 +1,4 @@ *.pyc /*.account /*.session +/*.group_session diff --git a/python/olm/__init__.py b/python/olm/__init__.py index 5520132..8681d12 100644 --- a/python/olm/__init__.py +++ b/python/olm/__init__.py @@ -1,2 +1,3 @@ from .account import Account from .session import Session +from .outbound_group_session import OutboundGroupSession diff --git a/python/olm/__main__.py b/python/olm/__main__.py index d2b0d38..8bb2419 100755 --- a/python/olm/__main__.py +++ b/python/olm/__main__.py @@ -8,7 +8,7 @@ import yaml from . import * -if __name__ == '__main__': +def build_arg_parser(): parser = argparse.ArgumentParser() parser.add_argument("--key", help="Account encryption key", default="") commands = parser.add_subparsers() @@ -206,5 +206,41 @@ if __name__ == '__main__': decrypt.set_defaults(func=do_decrypt) + outbound_group = commands.add_parser("outbound_group", help="Create an outbound group session") + outbound_group.add_argument("session_file", help="Local group session file") + outbound_group.set_defaults(func=do_outbound_group) + + group_encrypt = commands.add_parser("group_encrypt", help="Encrypt a group message") + group_encrypt.add_argument("session_file", help="Local group session file") + group_encrypt.add_argument("plaintext_file", help="Plaintext", + type=argparse.FileType('rb'), default=sys.stdin) + group_encrypt.add_argument("message_file", help="Message", + type=argparse.FileType('wb'), default=sys.stdout) + group_encrypt.set_defaults(func=do_group_encrypt) + + return parser + +def do_outbound_group(args): + if os.path.exists(args.session_file): + sys.stderr.write("Session %r file already exists" % ( + args.session_file, + )) + sys.exit(1) + session = OutboundGroupSession() + with open(args.session_file, "wb") as f: + f.write(session.pickle(args.key)) + +def do_group_encrypt(args): + session = OutboundGroupSession() + with open(args.session_file, "rb") as f: + session.unpickle(args.key, f.read()) + plaintext = args.plaintext_file.read() + message = session.encrypt(plaintext) + with open(args.session_file, "wb") as f: + f.write(session.pickle(args.key)) + args.message_file.write(message) + +if __name__ == '__main__': + parser = build_arg_parser() args = parser.parse_args() args.func(args) diff --git a/python/olm/outbound_group_session.py b/python/olm/outbound_group_session.py new file mode 100644 index 0000000..6182647 --- /dev/null +++ b/python/olm/outbound_group_session.py @@ -0,0 +1,83 @@ +import json + +from ._base import * + +lib.olm_outbound_group_session_size.argtypes = [] +lib.olm_outbound_group_session_size.restype = c_size_t + +lib.olm_outbound_group_session.argtypes = [c_void_p] +lib.olm_outbound_group_session.restype = c_void_p + +lib.olm_outbound_group_session_last_error.argtypes = [c_void_p] +lib.olm_outbound_group_session_last_error.restype = c_char_p + +def outbound_group_session_errcheck(res, func, args): + if res == ERR: + raise OlmError("%s: %s" % ( + func.__name__, lib.olm_outbound_group_session_last_error(args[0]) + )) + return res + + +def outbound_group_session_function(func, *types): + func.argtypes = (c_void_p,) + types + func.restypes = c_size_t + func.errcheck = outbound_group_session_errcheck + + +outbound_group_session_function( + lib.olm_pickle_outbound_group_session, c_void_p, c_size_t, c_void_p, c_size_t +) +outbound_group_session_function( + lib.olm_unpickle_outbound_group_session, c_void_p, c_size_t, c_void_p, c_size_t +) + +outbound_group_session_function(lib.olm_init_outbound_group_session_random_length) +outbound_group_session_function(lib.olm_init_outbound_group_session, c_void_p, c_size_t) +outbound_group_session_function(lib.olm_group_encrypt_message_length, c_size_t) +outbound_group_session_function(lib.olm_group_encrypt, + c_void_p, c_size_t, # Plaintext + c_void_p, c_size_t, # Message +) + + +class OutboundGroupSession(object): + def __init__(self): + self.buf = create_string_buffer(lib.olm_outbound_group_session_size()) + self.ptr = lib.olm_outbound_group_session(self.buf) + + random_length = lib.olm_init_outbound_group_session_random_length(self.ptr) + random = read_random(random_length) + random_buffer = create_string_buffer(random) + lib.olm_init_outbound_group_session(self.ptr, random_buffer, random_length) + + def pickle(self, key): + key_buffer = create_string_buffer(key) + pickle_length = lib.olm_pickle_outbound_group_session_length(self.ptr) + pickle_buffer = create_string_buffer(pickle_length) + lib.olm_pickle_outbound_group_session( + self.ptr, key_buffer, len(key), pickle_buffer, pickle_length + ) + return pickle_buffer.raw + + def unpickle(self, key, pickle): + key_buffer = create_string_buffer(key) + pickle_buffer = create_string_buffer(pickle) + lib.olm_unpickle_outbound_group_session( + self.ptr, key_buffer, len(key), pickle_buffer, len(pickle) + ) + + def encrypt(self, plaintext): + message_length = lib.olm_group_encrypt_message_length( + self.ptr, len(plaintext) + ) + message_buffer = create_string_buffer(message_length) + + plaintext_buffer = create_string_buffer(plaintext) + + lib.olm_group_encrypt( + self.ptr, + plaintext_buffer, len(plaintext), + message_buffer, message_length, + ) + return message_buffer.raw diff --git a/python/test_olm.sh b/python/test_olm.sh index b575cbf..0ea1623 100755 --- a/python/test_olm.sh +++ b/python/test_olm.sh @@ -4,11 +4,13 @@ OLM="python -m olm" ALICE_ACCOUNT=alice.account ALICE_SESSION=alice.session +ALICE_GROUP_SESSION=alice.group_session BOB_ACCOUNT=bob.account BOB_SESSION=bob.session rm $ALICE_ACCOUNT $BOB_ACCOUNT rm $ALICE_SESSION $BOB_SESSION +rm $ALICE_GROUP_SESSION $OLM create_account $ALICE_ACCOUNT $OLM create_account $BOB_ACCOUNT @@ -20,3 +22,9 @@ BOB_ONE_TIME_KEY="$($OLM keys --json $BOB_ACCOUNT | jq -r '.one_time_keys.curve2 $OLM outbound $ALICE_ACCOUNT $ALICE_SESSION "$BOB_IDENTITY_KEY" "$BOB_ONE_TIME_KEY" echo "Hello world" | $OLM encrypt $ALICE_SESSION - - | $OLM inbound $BOB_ACCOUNT $BOB_SESSION - - + + +### group sessions + +$OLM outbound_group $ALICE_GROUP_SESSION +echo "Hello world" | $OLM group_encrypt $ALICE_GROUP_SESSION - - -- cgit v1.2.3 From 8b1514c0a653ccc3f49db70131d7d4f7524f1f9b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 17:20:06 +0100 Subject: Implement functions to get the state of outbound session We need to be able to inspect an outbound session so that we can tell our peer how to set up an inbound session. --- include/olm/outbound_group_session.h | 59 ++++++++++++++++++++++++++++++++++-- python/olm/outbound_group_session.py | 24 +++++++++++++++ src/outbound_group_session.c | 49 ++++++++++++++++++++++++++++-- 3 files changed, 128 insertions(+), 4 deletions(-) diff --git a/include/olm/outbound_group_session.h b/include/olm/outbound_group_session.h index 27991ac..90859e9 100644 --- a/include/olm/outbound_group_session.h +++ b/include/olm/outbound_group_session.h @@ -90,7 +90,7 @@ size_t olm_init_outbound_group_session_random_length( ); /** - * Start a new outbound group session. Returns std::size_t(-1) on failure. On + * Start a new outbound group session. Returns olm_error() on failure. On * failure last_error will be set with an error code. The last_error will be * NOT_ENOUGH_RANDOM if the number of random bytes was too small. */ @@ -109,7 +109,7 @@ size_t olm_group_encrypt_message_length( /** * Encrypt some plain-text. Returns the length of the encrypted message or - * std::size_t(-1) on failure. On failure last_error will be set with an + * olm_error() on failure. On failure last_error will be set with an * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the output * buffer is too small. */ @@ -119,6 +119,61 @@ size_t olm_group_encrypt( uint8_t * message, size_t message_length ); + +/** + * Get the number of bytes returned by olm_outbound_group_session_id() + */ +size_t olm_outbound_group_session_id_length( + const OlmOutboundGroupSession *session +); + +/** + * Get a base64-encoded identifier for this session. + * + * Returns the length of the session id on success or olm_error() on + * failure. On failure last_error will be set with an error code. The + * last_error will be OUTPUT_BUFFER_TOO_SMALL if the id buffer was too + * small. + */ +size_t olm_outbound_group_session_id( + OlmOutboundGroupSession *session, + uint8_t * id, size_t id_length +); + +/** + * Get the current message index for this session. + * + * Each message is sent with an increasing index; this returns the index for + * the next message. + */ +uint32_t olm_outbound_group_session_message_index( + OlmOutboundGroupSession *session +); + +/** + * Get the number of bytes returned by olm_outbound_group_session_key() + */ +size_t olm_outbound_group_session_key_length( + const OlmOutboundGroupSession *session +); + +/** + * Get the base64-encoded current ratchet key for this session. + * + * Each message is sent with a diffent ratchet key. This function returns the + * ratchet key that will be used for the next message. + * + * Returns the length of the ratchet key on success or olm_error() on + * failure. On failure last_error will be set with an error code. The + * last_error will be OUTPUT_BUFFER_TOO_SMALL if the buffer was too small. + */ +size_t olm_outbound_group_session_key( + OlmOutboundGroupSession *session, + uint8_t * key, size_t key_length +); + + + #ifdef __cplusplus } // extern "C" #endif diff --git a/python/olm/outbound_group_session.py b/python/olm/outbound_group_session.py index 6182647..56f0962 100644 --- a/python/olm/outbound_group_session.py +++ b/python/olm/outbound_group_session.py @@ -34,12 +34,21 @@ outbound_group_session_function( outbound_group_session_function(lib.olm_init_outbound_group_session_random_length) outbound_group_session_function(lib.olm_init_outbound_group_session, c_void_p, c_size_t) + +lib.olm_outbound_group_session_message_index.argtypes = [c_void_p] +lib.olm_outbound_group_session_message_index.restype = c_uint32 + outbound_group_session_function(lib.olm_group_encrypt_message_length, c_size_t) outbound_group_session_function(lib.olm_group_encrypt, c_void_p, c_size_t, # Plaintext c_void_p, c_size_t, # Message ) +outbound_group_session_function(lib.olm_outbound_group_session_id_length) +outbound_group_session_function(lib.olm_outbound_group_session_id, c_void_p, c_size_t) +outbound_group_session_function(lib.olm_outbound_group_session_key_length) +outbound_group_session_function(lib.olm_outbound_group_session_key, c_void_p, c_size_t) + class OutboundGroupSession(object): def __init__(self): @@ -81,3 +90,18 @@ class OutboundGroupSession(object): message_buffer, message_length, ) return message_buffer.raw + + def session_id(self): + id_length = lib.olm_outbound_group_session_id_length(self.ptr) + id_buffer = create_string_buffer(id_length) + lib.olm_outbound_group_session_id(self.ptr, id_buffer, id_length); + return id_buffer.raw + + def message_index(self): + return lib.olm_outbound_group_session_message_index(self.ptr) + + def session_key(self): + key_length = lib.olm_outbound_group_session_key_length(self.ptr) + key_buffer = create_string_buffer(key_length) + lib.olm_outbound_group_session_key(self.ptr, key_buffer, key_length); + return key_buffer.raw diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index 8dc1cd1..fadf949 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -254,7 +254,52 @@ size_t olm_group_encrypt( megolm_advance(&(session->ratchet)); return _olm_encode_base64( - message_pos, rawmsglen, - message + message_pos, rawmsglen, message + ); +} + + +size_t olm_outbound_group_session_id_length( + const OlmOutboundGroupSession *session +) { + return _olm_encode_base64_length(GROUP_SESSION_ID_LENGTH); +} + +size_t olm_outbound_group_session_id( + OlmOutboundGroupSession *session, + uint8_t * id, size_t id_length +) { + if (id_length < olm_outbound_group_session_id_length(session)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + return _olm_encode_base64(session->session_id, GROUP_SESSION_ID_LENGTH, id); +} + +uint32_t olm_outbound_group_session_message_index( + OlmOutboundGroupSession *session +) { + return session->ratchet.counter; +} + +size_t olm_outbound_group_session_key_length( + const OlmOutboundGroupSession *session +) { + return _olm_encode_base64_length(MEGOLM_RATCHET_LENGTH); +} + +size_t olm_outbound_group_session_key( + OlmOutboundGroupSession *session, + uint8_t * key, size_t key_length +) { + if (key_length < olm_outbound_group_session_key_length(session)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + return _olm_encode_base64( + megolm_get_data(&session->ratchet), + MEGOLM_RATCHET_LENGTH, key ); } -- cgit v1.2.3 From 39ad75314b9e28053f568ed6a4109f5d3a9468fe Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 17:23:09 +0100 Subject: Implement decrypting inbound group messages Includes creation of inbound sessions, etc --- include/olm/error.h | 3 + include/olm/inbound_group_session.h | 153 +++++++++++++++++++++++++++ include/olm/message.h | 24 +++++ include/olm/olm.h | 1 + src/inbound_group_session.c | 199 ++++++++++++++++++++++++++++++++++++ src/message.cpp | 42 ++++++++ tests/test_group_session.cpp | 42 ++++++-- tests/test_message.cpp | 22 ++++ 8 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 include/olm/inbound_group_session.h create mode 100644 src/inbound_group_session.c diff --git a/include/olm/error.h b/include/olm/error.h index 87e019a..3f74992 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,6 +32,9 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ + OLM_BAD_RATCHET_KEY = 11, + OLM_BAD_CHAIN_INDEX = 12, + /* remember to update the list of string constants in error.c when updating * this list. */ }; diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h new file mode 100644 index 0000000..4cf4ac4 --- /dev/null +++ b/include/olm/inbound_group_session.h @@ -0,0 +1,153 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OLM_INBOUND_GROUP_SESSION_H_ +#define OLM_INBOUND_GROUP_SESSION_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct OlmInboundGroupSession OlmInboundGroupSession; + +/** get the size of an inbound group session, in bytes. */ +size_t olm_inbound_group_session_size(); + +/** + * Initialise an inbound group session object using the supplied memory + * The supplied memory should be at least olm_inbound_group_session_size() + * bytes. + */ +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +); + +/** + * A null terminated string describing the most recent error to happen to a + * group session */ +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +); + +/** Clears the memory used to back this group session */ +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +); + +/** Returns the number of bytes needed to store an inbound group session */ +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +); + +/** + * Stores a group session as a base64 string. Encrypts the session using the + * supplied key. Returns the length of the session on success. + * + * Returns olm_error() on failure. If the pickle output buffer + * is smaller than olm_pickle_inbound_group_session_length() then + * olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL" + */ +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + +/** + * Loads a group session from a pickled base64 string. Decrypts the session + * using the supplied key. + * + * Returns olm_error() on failure. If the key doesn't match the one used to + * encrypt the account then olm_inbound_group_session_last_error() will be + * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then + * olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input + * pickled buffer is destroyed + */ +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + + +/** + * Start a new inbound group session, based on the parameters supplied. + * + * Returns olm_error() on failure. On failure last_error will be set with an + * error code. The last_error will be: + * + * * OLM_INVALID_BASE64 if the session_key is not valid base64 + * * OLM_BAD_RATCHET_KEY if the session_key is invalid + */ +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + + /* base64-encoded key */ + uint8_t const * session_key, size_t session_key_length +); + +/** + * Get an upper bound on the number of bytes of plain-text the decrypt method + * will write for a given input message length. The actual size could be + * different due to padding. + * + * The input message buffer is destroyed. + * + * Returns olm_error() on failure. + */ +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +); + +/** + * Decrypt a message. + * + * The input message buffer is destroyed. + * + * Returns the length of the decrypted plain-text, or olm_error() on failure. + * + * On failure last_error will be set with an error code. The last_error will + * be: + * * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small + * * OLM_INVALID_BASE64 if the message is not valid base-64 + * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported + * version of the protocol + * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the + * message's index (ie, it was sent before the ratchet key was shared with + * us) + */ +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + + /* input; note that it will be overwritten with the base64-decoded + message. */ + uint8_t * message, size_t message_length, + + /* output */ + uint8_t * plaintext, size_t max_plaintext_length +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_INBOUND_GROUP_SESSION_H_ */ diff --git a/include/olm/message.h b/include/olm/message.h index 05fb56c..bd7aec3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -65,6 +65,30 @@ void _olm_encode_group_message( ); +struct _OlmDecodeGroupMessageResults { + uint8_t version; + const uint8_t *session_id; + size_t session_id_length; + uint32_t chain_index; + int has_chain_index; + const uint8_t *ciphertext; + size_t ciphertext_length; +}; + + +/** + * Reads the message headers from the input buffer. + */ +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + + /* output structure: updated with results */ + struct _OlmDecodeGroupMessageResults *results +); + + + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/olm/olm.h b/include/olm/olm.h index 00e1f63..dbaf71e 100644 --- a/include/olm/olm.h +++ b/include/olm/olm.h @@ -19,6 +19,7 @@ #include #include +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #ifdef __cplusplus diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c new file mode 100644 index 0000000..4796414 --- /dev/null +++ b/src/inbound_group_session.c @@ -0,0 +1,199 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/inbound_group_session.h" + +#include + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/message.h" + +#define OLM_PROTOCOL_VERSION 3 + +struct OlmInboundGroupSession { + /** our earliest known ratchet value */ + Megolm initial_ratchet; + + /** The most recent ratchet value */ + Megolm latest_ratchet; + + enum OlmErrorCode last_error; +}; + +size_t olm_inbound_group_session_size() { + return sizeof(OlmInboundGroupSession); +} + +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +) { + OlmInboundGroupSession *session = memory; + olm_clear_inbound_group_session(session); + return session; +} + +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +) { + memset(session, 0, sizeof(OlmInboundGroupSession)); + return sizeof(OlmInboundGroupSession); +} + +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + const uint8_t * session_key, size_t session_key_length +) { + uint8_t key_buf[MEGOLM_RATCHET_LENGTH]; + size_t raw_length = _olm_decode_base64_length(session_key_length); + + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + if (raw_length != MEGOLM_RATCHET_LENGTH) { + session->last_error = OLM_BAD_RATCHET_KEY; + return (size_t)-1; + } + + _olm_decode_base64(session_key, session_key_length, key_buf); + megolm_init(&session->initial_ratchet, key_buf, message_index); + megolm_init(&session->latest_ratchet, key_buf, message_index); + memset(key_buf, 0, MEGOLM_RATCHET_LENGTH); + + return 0; +} + +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + size_t r; + const struct _olm_cipher *cipher = megolm_cipher(); + struct _OlmDecodeGroupMessageResults decoded_results; + + r = _olm_decode_base64(message, message_length, message); + if (r == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return r; + } + + _olm_decode_group_message( + message, message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.ciphertext) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + return cipher->ops->decrypt_max_plaintext_length( + cipher, decoded_results.ciphertext_length); +} + + +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + struct _OlmDecodeGroupMessageResults decoded_results; + const struct _olm_cipher *cipher = megolm_cipher(); + size_t max_length, raw_message_length, r; + Megolm *megolm; + Megolm tmp_megolm; + + raw_message_length = _olm_decode_base64(message, message_length, message); + if (raw_message_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + _olm_decode_group_message( + message, raw_message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.has_chain_index || !decoded_results.session_id + || !decoded_results.ciphertext + ) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + max_length = cipher->ops->decrypt_max_plaintext_length( + cipher, + decoded_results.ciphertext_length + ); + if (max_plaintext_length < max_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* pick a megolm instance to use. If we're at or beyond the latest ratchet + * value, use that */ + if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + megolm = &session->latest_ratchet; + } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + /* the counter is before our intial ratchet - we can't decode this. */ + session->last_error = OLM_BAD_CHAIN_INDEX; + return (size_t)-1; + } else { + /* otherwise, start from the initial megolm. Take a copy so that we + * don't overwrite the initial megolm */ + tmp_megolm = session->initial_ratchet; + megolm = &tmp_megolm; + } + + megolm_advance_to(megolm, decoded_results.chain_index); + + /* now try checking the mac, and decrypting */ + r = cipher->ops->decrypt( + cipher, + megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, + message, raw_message_length, + decoded_results.ciphertext, decoded_results.ciphertext_length, + plaintext, max_plaintext_length + ); + + memset(&tmp_megolm, 0, sizeof(tmp_megolm)); + if (r == (size_t)-1) { + session->last_error = OLM_BAD_MESSAGE_MAC; + return r; + } + + return r; +} diff --git a/src/message.cpp b/src/message.cpp index df0c7bb..ec44262 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -363,3 +363,45 @@ void _olm_encode_group_message( pos = encode(pos, COUNTER_TAG, chain_index); pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } + +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + struct _OlmDecodeGroupMessageResults *results +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length - mac_length; + std::uint8_t const * unknown = nullptr; + + results->session_id = nullptr; + results->session_id_length = 0; + bool has_chain_index = false; + results->chain_index = 0; + results->ciphertext = nullptr; + results->ciphertext_length = 0; + + if (pos == end) return; + if (input_length < mac_length) return; + results->version = *(pos++); + + while (pos != end) { + pos = decode( + pos, end, GROUP_SESSION_ID_TAG, + results->session_id, results->session_id_length + ); + pos = decode( + pos, end, COUNTER_TAG, + results->chain_index, has_chain_index + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + results->ciphertext, results->ciphertext_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + unknown = pos; + } + + results->has_chain_index = (int)has_chain_index; +} diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index b9fe1ef..5bbdc9d 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #include "unittest.hh" @@ -19,11 +20,10 @@ int main() { { - TestCase test_case("Pickle outbound group"); size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); size_t pickle_length = olm_pickle_outbound_group_session_length(session); @@ -61,9 +61,9 @@ int main() { "0123456789ABDEF0123456789ABCDEF"; - + /* build the outbound session */ size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); assert_equals((size_t)132, @@ -73,18 +73,48 @@ int main() { session, random_bytes, sizeof(random_bytes)); assert_equals((size_t)0, res); + assert_equals(0U, olm_outbound_group_session_message_index(session)); + size_t session_key_len = olm_outbound_group_session_key_length(session); + uint8_t session_key[session_key_len]; + olm_outbound_group_session_key(session, session_key, session_key_len); + + + /* encode the message */ uint8_t plaintext[] = "Message"; size_t plaintext_length = sizeof(plaintext) - 1; size_t msglen = olm_group_encrypt_message_length( session, plaintext_length); - uint8_t *msg = (uint8_t *)alloca(msglen); + uint8_t msg[msglen]; res = olm_group_encrypt(session, plaintext, plaintext_length, msg, msglen); assert_equals(msglen, res); + assert_equals(1U, olm_outbound_group_session_message_index(session)); + + + /* build the inbound session */ + size = olm_inbound_group_session_size(); + uint8_t inbound_session_memory[size]; + OlmInboundGroupSession *inbound_session = + olm_inbound_group_session(inbound_session_memory); + + res = olm_init_inbound_group_session( + inbound_session, 0U, session_key, session_key_len); + assert_equals((size_t)0, res); - // TODO: decode the message + /* decode the message */ + + /* olm_group_decrypt_max_plaintext_length destroys the input so we have to + copy it. */ + uint8_t msgcopy[msglen]; + memcpy(msgcopy, msg, msglen); + size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); + uint8_t plaintext_buf[size]; + res = olm_group_decrypt(inbound_session, msg, msglen, + plaintext_buf, size); + assert_equals(plaintext_length, res); + assert_equals(plaintext, plaintext_buf, res); } } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index e2385ea..5fec9e0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -97,4 +97,26 @@ assert_equals(message2, output, 35); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); } /* group message encode test */ +{ + TestCase test_case("Group message decode test"); + + struct _OlmDecodeGroupMessageResults results; + std::uint8_t message[] = + "\x03" + "\x2A\x09sessionid" + "\x10\xc8\x01" + "\x22\x0A" "ciphertext" + "hmacsha2"; + + const uint8_t expected_session_id[] = "sessionid"; + + _olm_decode_group_message(message, sizeof(message)-1, 8, &results); + assert_equals(std::uint8_t(3), results.version); + assert_equals(std::size_t(9), results.session_id_length); + assert_equals(expected_session_id, results.session_id, 9); + assert_equals(1, results.has_chain_index); + assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(std::size_t(10), results.ciphertext_length); + assert_equals(ciphertext, results.ciphertext, 10); +} /* group message decode test */ } -- cgit v1.2.3 From a073d12d8367d27db97751d46b766e8480fd39e4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 18:03:59 +0100 Subject: Support for pickling inbound group sessions --- src/inbound_group_session.c | 76 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_group_session.cpp | 33 ++++++++++++++++++- 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 4796414..34908a9 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -22,8 +22,12 @@ #include "olm/error.h" #include "olm/megolm.h" #include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" + #define OLM_PROTOCOL_VERSION 3 +#define PICKLE_VERSION 1 struct OlmInboundGroupSession { /** our earliest known ratchet value */ @@ -86,6 +90,78 @@ size_t olm_init_inbound_group_session( return 0; } +static size_t raw_pickle_length( + const OlmInboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&session->initial_ratchet); + length += megolm_pickle_length(&session->latest_ratchet); + return length; +} + +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&session->initial_ratchet, pos); + pos = megolm_pickle(&session->latest_ratchet, pos); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + pos = megolm_unpickle(&session->initial_ratchet, pos, end); + pos = megolm_unpickle(&session->latest_ratchet, pos, end); + + if (end != pos) { + /* We had the wrong number of bytes in the input. */ + session->last_error = OLM_CORRUPTED_PICKLE; + return (size_t)-1; + } + + return pickled_length; +} + size_t olm_group_decrypt_max_plaintext_length( OlmInboundGroupSession *session, uint8_t * message, size_t message_length diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 5bbdc9d..4a82154 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -20,7 +20,7 @@ int main() { { - TestCase test_case("Pickle outbound group"); + TestCase test_case("Pickle outbound group session"); size_t size = olm_outbound_group_session_size(); uint8_t memory[size]; @@ -50,6 +50,37 @@ int main() { } +{ + TestCase test_case("Pickle inbound group session"); + + size_t size = olm_inbound_group_session_size(); + uint8_t memory[size]; + OlmInboundGroupSession *session = olm_inbound_group_session(memory); + + size_t pickle_length = olm_pickle_inbound_group_session_length(session); + uint8_t pickle1[pickle_length]; + olm_pickle_inbound_group_session(session, + "secret_key", 10, + pickle1, pickle_length); + uint8_t pickle2[pickle_length]; + memcpy(pickle2, pickle1, pickle_length); + + uint8_t buffer2[size]; + OlmInboundGroupSession *session2 = olm_inbound_group_session(buffer2); + size_t res = olm_unpickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + assert_not_equals((size_t)-1, res); + assert_equals(pickle_length, + olm_pickle_inbound_group_session_length(session2)); + olm_pickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + + assert_equals(pickle1, pickle2, pickle_length); +} + + { TestCase test_case("Group message send/receive"); -- cgit v1.2.3 From fc4756ddf17f536912a89a4ffcf90a309c236ced Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 19 May 2016 07:53:07 +0100 Subject: Fix up some names, and protobuf tags Make names (of session_key and message_index) more consistent. Use our own protobuf tags rather than trying to piggyback on the one-to-one structure. --- include/olm/error.h | 8 ++++++-- include/olm/inbound_group_session.h | 8 ++++---- include/olm/message.h | 8 ++++---- src/error.c | 2 ++ src/inbound_group_session.c | 12 ++++++------ src/message.cpp | 26 ++++++++++++++------------ tests/test_message.cpp | 16 ++++++++-------- 7 files changed, 44 insertions(+), 36 deletions(-) diff --git a/include/olm/error.h b/include/olm/error.h index 3f74992..98d2cf5 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,8 +32,12 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ - OLM_BAD_RATCHET_KEY = 11, - OLM_BAD_CHAIN_INDEX = 12, + OLM_BAD_SESSION_KEY = 11, /*!< Attempt to initialise an inbound group + session from an invalid session key */ + OLM_UNKNOWN_MESSAGE_INDEX = 12, /*!< Attempt to decode a message whose + * index is earlier than our earliest + * known session key. + */ /* remember to update the list of string constants in error.c when updating * this list. */ diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h index 4cf4ac4..e24f377 100644 --- a/include/olm/inbound_group_session.h +++ b/include/olm/inbound_group_session.h @@ -91,7 +91,7 @@ size_t olm_unpickle_inbound_group_session( * error code. The last_error will be: * * * OLM_INVALID_BASE64 if the session_key is not valid base64 - * * OLM_BAD_RATCHET_KEY if the session_key is invalid + * * OLM_BAD_SESSION_KEY if the session_key is invalid */ size_t olm_init_inbound_group_session( OlmInboundGroupSession *session, @@ -129,9 +129,9 @@ size_t olm_group_decrypt_max_plaintext_length( * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported * version of the protocol * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded - * * OLM_BAD_MESSAGE_MAC if the message could not be verified - * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the - * message's index (ie, it was sent before the ratchet key was shared with + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_UNKNOWN_MESSAGE_INDEX if we do not have a session key corresponding to the + * message's index (ie, it was sent before the session key was shared with * us) */ size_t olm_group_decrypt( diff --git a/include/olm/message.h b/include/olm/message.h index bd7aec3..cff15f3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -47,7 +47,7 @@ size_t _olm_encode_group_message_length( * version: version number of the olm protocol * session_id: group session identifier * session_id_length: length of session_id - * chain_index: message index + * message_index: message index * ciphertext_length: length of the ciphertext * output: where to write the output. Should be at least * olm_encode_group_message_length() bytes long. @@ -58,7 +58,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -69,8 +69,8 @@ struct _OlmDecodeGroupMessageResults { uint8_t version; const uint8_t *session_id; size_t session_id_length; - uint32_t chain_index; - int has_chain_index; + uint32_t message_index; + int has_message_index; const uint8_t *ciphertext; size_t ciphertext_length; }; diff --git a/src/error.c b/src/error.c index 0690856..bd8a39d 100644 --- a/src/error.c +++ b/src/error.c @@ -27,6 +27,8 @@ static const char * ERRORS[] = { "BAD_ACCOUNT_KEY", "UNKNOWN_PICKLE_VERSION", "CORRUPTED_PICKLE", + "BAD_SESSION_KEY", + "UNKNOWN_MESSAGE_INDEX", }; const char * _olm_error_to_string(enum OlmErrorCode error) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 34908a9..cc6ba5e 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -78,7 +78,7 @@ size_t olm_init_inbound_group_session( } if (raw_length != MEGOLM_RATCHET_LENGTH) { - session->last_error = OLM_BAD_RATCHET_KEY; + session->last_error = OLM_BAD_SESSION_KEY; return (size_t)-1; } @@ -223,7 +223,7 @@ size_t olm_group_decrypt( return (size_t)-1; } - if (!decoded_results.has_chain_index || !decoded_results.session_id + if (!decoded_results.has_message_index || !decoded_results.session_id || !decoded_results.ciphertext ) { session->last_error = OLM_BAD_MESSAGE_FORMAT; @@ -241,11 +241,11 @@ size_t olm_group_decrypt( /* pick a megolm instance to use. If we're at or beyond the latest ratchet * value, use that */ - if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + if ((int32_t)(decoded_results.message_index - session->latest_ratchet.counter) >= 0) { megolm = &session->latest_ratchet; - } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + } else if ((int32_t)(decoded_results.message_index - session->initial_ratchet.counter) < 0) { /* the counter is before our intial ratchet - we can't decode this. */ - session->last_error = OLM_BAD_CHAIN_INDEX; + session->last_error = OLM_UNKNOWN_MESSAGE_INDEX; return (size_t)-1; } else { /* otherwise, start from the initial megolm. Take a copy so that we @@ -254,7 +254,7 @@ size_t olm_group_decrypt( megolm = &tmp_megolm; } - megolm_advance_to(megolm, decoded_results.chain_index); + megolm_advance_to(megolm, decoded_results.message_index); /* now try checking the mac, and decrypting */ r = cipher->ops->decrypt( diff --git a/src/message.cpp b/src/message.cpp index ec44262..ab4300e 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -328,17 +328,19 @@ void olm::decode_one_time_key_message( -static std::uint8_t const GROUP_SESSION_ID_TAG = 052; +static const std::uint8_t GROUP_SESSION_ID_TAG = 012; +static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 020; +static const std::uint8_t GROUP_CIPHERTEXT_TAG = 032; size_t _olm_encode_group_message_length( size_t group_session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, size_t mac_length ) { size_t length = VERSION_LENGTH; length += 1 + varstring_length(group_session_id_length); - length += 1 + varint_length(chain_index); + length += 1 + varint_length(message_index); length += 1 + varstring_length(ciphertext_length); length += mac_length; return length; @@ -349,7 +351,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -360,8 +362,8 @@ void _olm_encode_group_message( *(pos++) = version; pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length); std::memcpy(session_id_pos, session_id, session_id_length); - pos = encode(pos, COUNTER_TAG, chain_index); - pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); + pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index); + pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } void _olm_decode_group_message( @@ -375,8 +377,8 @@ void _olm_decode_group_message( results->session_id = nullptr; results->session_id_length = 0; - bool has_chain_index = false; - results->chain_index = 0; + bool has_message_index = false; + results->message_index = 0; results->ciphertext = nullptr; results->ciphertext_length = 0; @@ -390,11 +392,11 @@ void _olm_decode_group_message( results->session_id, results->session_id_length ); pos = decode( - pos, end, COUNTER_TAG, - results->chain_index, has_chain_index + pos, end, GROUP_MESSAGE_INDEX_TAG, + results->message_index, has_message_index ); pos = decode( - pos, end, CIPHERTEXT_TAG, + pos, end, GROUP_CIPHERTEXT_TAG, results->ciphertext, results->ciphertext_length ); if (unknown == pos) { @@ -403,5 +405,5 @@ void _olm_decode_group_message( unknown = pos; } - results->has_chain_index = (int)has_chain_index; + results->has_message_index = (int)has_message_index; } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index 5fec9e0..30c10a0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -89,9 +89,9 @@ assert_equals(message2, output, 35); uint8_t expected[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0a"; + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A"; assert_equals(expected, output, sizeof(expected)-1); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); @@ -103,9 +103,9 @@ assert_equals(message2, output, 35); struct _OlmDecodeGroupMessageResults results; std::uint8_t message[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0A" "ciphertext" + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A" "ciphertext" "hmacsha2"; const uint8_t expected_session_id[] = "sessionid"; @@ -114,8 +114,8 @@ assert_equals(message2, output, 35); assert_equals(std::uint8_t(3), results.version); assert_equals(std::size_t(9), results.session_id_length); assert_equals(expected_session_id, results.session_id, 9); - assert_equals(1, results.has_chain_index); - assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(1, results.has_message_index); + assert_equals(std::uint32_t(200), results.message_index); assert_equals(std::size_t(10), results.ciphertext_length); assert_equals(ciphertext, results.ciphertext, 10); } /* group message decode test */ -- cgit v1.2.3 From 846ab858a6dd2e962d3e110147f4274416026f5a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 19 May 2016 12:01:15 +0100 Subject: Python wrapper: support for inbound group sessions --- python/olm/__init__.py | 1 + python/olm/__main__.py | 76 +++++++++++++++++++++++++++++--- python/olm/inbound_group_session.py | 86 +++++++++++++++++++++++++++++++++++++ python/test_olm.sh | 6 ++- 4 files changed, 162 insertions(+), 7 deletions(-) create mode 100644 python/olm/inbound_group_session.py diff --git a/python/olm/__init__.py b/python/olm/__init__.py index 8681d12..31b29b9 100644 --- a/python/olm/__init__.py +++ b/python/olm/__init__.py @@ -1,3 +1,4 @@ from .account import Account from .session import Session from .outbound_group_session import OutboundGroupSession +from .inbound_group_session import InboundGroupSession diff --git a/python/olm/__main__.py b/python/olm/__main__.py index 8bb2419..3fcf24d 100755 --- a/python/olm/__main__.py +++ b/python/olm/__main__.py @@ -210,14 +210,43 @@ def build_arg_parser(): outbound_group.add_argument("session_file", help="Local group session file") outbound_group.set_defaults(func=do_outbound_group) + group_credentials = commands.add_parser("group_credentials", help="Export the current outbound group session credentials") + group_credentials.add_argument("session_file", help="Local outbound group session file") + group_credentials.add_argument("credentials_file", help="File to write credentials to (default stdout)", + type=argparse.FileType('w'), nargs='?', + default=sys.stdout) + group_credentials.set_defaults(func=do_group_credentials) + group_encrypt = commands.add_parser("group_encrypt", help="Encrypt a group message") - group_encrypt.add_argument("session_file", help="Local group session file") - group_encrypt.add_argument("plaintext_file", help="Plaintext", - type=argparse.FileType('rb'), default=sys.stdin) - group_encrypt.add_argument("message_file", help="Message", - type=argparse.FileType('wb'), default=sys.stdout) + group_encrypt.add_argument("session_file", help="Local outbound group session file") + group_encrypt.add_argument("plaintext_file", help="Plaintext file (default stdin)", + type=argparse.FileType('rb'), nargs='?', + default=sys.stdin) + group_encrypt.add_argument("message_file", help="Message file (default stdout)", + type=argparse.FileType('w'), nargs='?', + default=sys.stdout) group_encrypt.set_defaults(func=do_group_encrypt) + inbound_group = commands.add_parser( + "inbound_group", + help=("Create an inbound group session based on credentials from an "+ + "outbound group session")) + inbound_group.add_argument("session_file", help="Local inbound group session file") + inbound_group.add_argument("credentials_file", + help="File to read credentials from (default stdin)", + type=argparse.FileType('r'), nargs='?', + default=sys.stdin) + inbound_group.set_defaults(func=do_inbound_group) + + group_decrypt = commands.add_parser("group_decrypt", help="Decrypt a group message") + group_decrypt.add_argument("session_file", help="Local inbound group session file") + group_decrypt.add_argument("message_file", help="Message file (default stdin)", + type=argparse.FileType('r'), nargs='?', + default=sys.stdin) + group_decrypt.add_argument("plaintext_file", help="Plaintext file (default stdout)", + type=argparse.FileType('wb'), nargs='?', + default=sys.stdout) + group_decrypt.set_defaults(func=do_group_decrypt) return parser def do_outbound_group(args): @@ -240,6 +269,43 @@ def do_group_encrypt(args): f.write(session.pickle(args.key)) args.message_file.write(message) +def do_group_credentials(args): + session = OutboundGroupSession() + with open(args.session_file, "rb") as f: + session.unpickle(args.key, f.read()) + result = { + 'message_index': session.message_index(), + 'session_key': session.session_key(), + } + json.dump(result, args.credentials_file, indent=4) + +def do_inbound_group(args): + if os.path.exists(args.session_file): + sys.stderr.write("Session %r file already exists\n" % ( + args.session_file, + )) + sys.exit(1) + credentials = json.load(args.credentials_file) + for k in ('message_index', 'session_key'): + if not k in credentials: + sys.stderr.write("Credentials file is missing %s\n" % k) + sys.exit(1); + + session = InboundGroupSession() + session.init(credentials['message_index'], credentials['session_key']) + with open(args.session_file, "wb") as f: + f.write(session.pickle(args.key)) + +def do_group_decrypt(args): + session = InboundGroupSession() + with open(args.session_file, "rb") as f: + session.unpickle(args.key, f.read()) + message = args.message_file.read() + plaintext = session.decrypt(message) + with open(args.session_file, "wb") as f: + f.write(session.pickle(args.key)) + args.plaintext_file.write(plaintext) + if __name__ == '__main__': parser = build_arg_parser() args = parser.parse_args() diff --git a/python/olm/inbound_group_session.py b/python/olm/inbound_group_session.py new file mode 100644 index 0000000..6c01095 --- /dev/null +++ b/python/olm/inbound_group_session.py @@ -0,0 +1,86 @@ +import json + +from ._base import * + +lib.olm_inbound_group_session_size.argtypes = [] +lib.olm_inbound_group_session_size.restype = c_size_t + +lib.olm_inbound_group_session.argtypes = [c_void_p] +lib.olm_inbound_group_session.restype = c_void_p + +lib.olm_inbound_group_session_last_error.argtypes = [c_void_p] +lib.olm_inbound_group_session_last_error.restype = c_char_p + +def inbound_group_session_errcheck(res, func, args): + if res == ERR: + raise OlmError("%s: %s" % ( + func.__name__, lib.olm_inbound_group_session_last_error(args[0]) + )) + return res + + +def inbound_group_session_function(func, *types): + func.argtypes = (c_void_p,) + types + func.restypes = c_size_t + func.errcheck = inbound_group_session_errcheck + + +inbound_group_session_function( + lib.olm_pickle_inbound_group_session, c_void_p, c_size_t, c_void_p, c_size_t +) +inbound_group_session_function( + lib.olm_unpickle_inbound_group_session, c_void_p, c_size_t, c_void_p, c_size_t +) + +inbound_group_session_function( + lib.olm_init_inbound_group_session, c_uint32, c_void_p, c_size_t +) + +inbound_group_session_function( + lib.olm_group_decrypt_max_plaintext_length, c_void_p, c_size_t +) +inbound_group_session_function( + lib.olm_group_decrypt, + c_void_p, c_size_t, # message + c_void_p, c_size_t, # plaintext +) + +class InboundGroupSession(object): + def __init__(self): + self.buf = create_string_buffer(lib.olm_inbound_group_session_size()) + self.ptr = lib.olm_inbound_group_session(self.buf) + + def pickle(self, key): + key_buffer = create_string_buffer(key) + pickle_length = lib.olm_pickle_inbound_group_session_length(self.ptr) + pickle_buffer = create_string_buffer(pickle_length) + lib.olm_pickle_inbound_group_session( + self.ptr, key_buffer, len(key), pickle_buffer, pickle_length + ) + return pickle_buffer.raw + + def unpickle(self, key, pickle): + key_buffer = create_string_buffer(key) + pickle_buffer = create_string_buffer(pickle) + lib.olm_unpickle_inbound_group_session( + self.ptr, key_buffer, len(key), pickle_buffer, len(pickle) + ) + + def init(self, message_index, session_key): + key_buffer = create_string_buffer(session_key) + lib.olm_init_inbound_group_session( + self.ptr, message_index, key_buffer, len(session_key) + ) + + def decrypt(self, message): + message_buffer = create_string_buffer(message) + max_plaintext_length = lib.olm_group_decrypt_max_plaintext_length( + self.ptr, message_buffer, len(message) + ) + plaintext_buffer = create_string_buffer(max_plaintext_length) + message_buffer = create_string_buffer(message) + plaintext_length = lib.olm_group_decrypt( + self.ptr, message_buffer, len(message), + plaintext_buffer, max_plaintext_length + ) + return plaintext_buffer.raw[:plaintext_length] diff --git a/python/test_olm.sh b/python/test_olm.sh index 0ea1623..da69581 100755 --- a/python/test_olm.sh +++ b/python/test_olm.sh @@ -7,10 +7,11 @@ ALICE_SESSION=alice.session ALICE_GROUP_SESSION=alice.group_session BOB_ACCOUNT=bob.account BOB_SESSION=bob.session +BOB_GROUP_SESSION=bob.group_session rm $ALICE_ACCOUNT $BOB_ACCOUNT rm $ALICE_SESSION $BOB_SESSION -rm $ALICE_GROUP_SESSION +rm $ALICE_GROUP_SESSION $BOB_GROUP_SESSION $OLM create_account $ALICE_ACCOUNT $OLM create_account $BOB_ACCOUNT @@ -27,4 +28,5 @@ echo "Hello world" | $OLM encrypt $ALICE_SESSION - - | $OLM inbound $BOB_ACCOUNT ### group sessions $OLM outbound_group $ALICE_GROUP_SESSION -echo "Hello world" | $OLM group_encrypt $ALICE_GROUP_SESSION - - +$OLM group_credentials $ALICE_GROUP_SESSION | $OLM inbound_group $BOB_GROUP_SESSION +echo "Hello group" | $OLM group_encrypt $ALICE_GROUP_SESSION - - | $OLM group_decrypt $BOB_GROUP_SESSION -- cgit v1.2.3 From 173cbe112c139de0bd1a69dce5a03db360dc5abc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 May 2016 12:40:59 +0100 Subject: Avoid relying on uint -> int casting behaviour The behaviour when casting from a uint32_t which has overflowed (so has the top bit set) to int32_t is implementation-defined, so let's avoid relying on it. --- src/inbound_group_session.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index cc6ba5e..b8f762d 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -241,9 +241,9 @@ size_t olm_group_decrypt( /* pick a megolm instance to use. If we're at or beyond the latest ratchet * value, use that */ - if ((int32_t)(decoded_results.message_index - session->latest_ratchet.counter) >= 0) { + if ((decoded_results.message_index - session->latest_ratchet.counter) < (1U << 31)) { megolm = &session->latest_ratchet; - } else if ((int32_t)(decoded_results.message_index - session->initial_ratchet.counter) < 0) { + } else if ((decoded_results.message_index - session->initial_ratchet.counter) >= (1U << 31)) { /* the counter is before our intial ratchet - we can't decode this. */ session->last_error = OLM_UNKNOWN_MESSAGE_INDEX; return (size_t)-1; -- cgit v1.2.3 From fa1e9446ac2b4d26dd592813ce0a372565df4c93 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 May 2016 12:35:59 +0100 Subject: Use _olm_unset instead of memset memset is at risk of being optimised away, so use _olm_unset instead. --- src/inbound_group_session.c | 7 ++++--- src/outbound_group_session.c | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index b8f762d..6cded75 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -21,6 +21,7 @@ #include "olm/cipher.h" #include "olm/error.h" #include "olm/megolm.h" +#include "olm/memory.h" #include "olm/message.h" #include "olm/pickle.h" #include "olm/pickle_encoding.h" @@ -60,7 +61,7 @@ const char *olm_inbound_group_session_last_error( size_t olm_clear_inbound_group_session( OlmInboundGroupSession *session ) { - memset(session, 0, sizeof(OlmInboundGroupSession)); + _olm_unset(session, sizeof(OlmInboundGroupSession)); return sizeof(OlmInboundGroupSession); } @@ -85,7 +86,7 @@ size_t olm_init_inbound_group_session( _olm_decode_base64(session_key, session_key_length, key_buf); megolm_init(&session->initial_ratchet, key_buf, message_index); megolm_init(&session->latest_ratchet, key_buf, message_index); - memset(key_buf, 0, MEGOLM_RATCHET_LENGTH); + _olm_unset(key_buf, MEGOLM_RATCHET_LENGTH); return 0; } @@ -265,7 +266,7 @@ size_t olm_group_decrypt( plaintext, max_plaintext_length ); - memset(&tmp_megolm, 0, sizeof(tmp_megolm)); + _olm_unset(&tmp_megolm, sizeof(tmp_megolm)); if (r == (size_t)-1) { session->last_error = OLM_BAD_MESSAGE_MAC; return r; diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index fadf949..cf7d32c 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -22,6 +22,7 @@ #include "olm/cipher.h" #include "olm/error.h" #include "olm/megolm.h" +#include "olm/memory.h" #include "olm/message.h" #include "olm/pickle.h" #include "olm/pickle_encoding.h" @@ -63,7 +64,7 @@ const char *olm_outbound_group_session_last_error( size_t olm_clear_outbound_group_session( OlmOutboundGroupSession *session ) { - memset(session, 0, sizeof(OlmOutboundGroupSession)); + _olm_unset(session, sizeof(OlmOutboundGroupSession)); return sizeof(OlmOutboundGroupSession); } -- cgit v1.2.3 From a919a149fbb192e3fae7aba921ca28e02d9c0d10 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 14:54:01 +0100 Subject: Update megolm_cipher as a global struct Initialise megolm_cipher via the preprocessor macro, instead of with a function. --- include/olm/megolm.h | 4 ++-- src/inbound_group_session.c | 18 ++++++++---------- src/megolm.c | 15 +++------------ src/outbound_group_session.c | 16 +++++++--------- 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 831c6fb..e4e5d0b 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -49,11 +49,11 @@ typedef struct Megolm { /** - * Get the cipher used in megolm-backed conversations + * The cipher used in megolm-backed conversations * * (AES256 + SHA256, with keys based on an HKDF with info of MEGOLM_KEYS) */ -const struct _olm_cipher *megolm_cipher(); +extern const struct _olm_cipher *megolm_cipher; /** * initialize the megolm ratchet. random_data should be at least diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 6cded75..b6894c1 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -168,7 +168,6 @@ size_t olm_group_decrypt_max_plaintext_length( uint8_t * message, size_t message_length ) { size_t r; - const struct _olm_cipher *cipher = megolm_cipher(); struct _OlmDecodeGroupMessageResults decoded_results; r = _olm_decode_base64(message, message_length, message); @@ -179,7 +178,7 @@ size_t olm_group_decrypt_max_plaintext_length( _olm_decode_group_message( message, message_length, - cipher->ops->mac_length(cipher), + megolm_cipher->ops->mac_length(megolm_cipher), &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -192,8 +191,8 @@ size_t olm_group_decrypt_max_plaintext_length( return (size_t)-1; } - return cipher->ops->decrypt_max_plaintext_length( - cipher, decoded_results.ciphertext_length); + return megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, decoded_results.ciphertext_length); } @@ -203,7 +202,6 @@ size_t olm_group_decrypt( uint8_t * plaintext, size_t max_plaintext_length ) { struct _OlmDecodeGroupMessageResults decoded_results; - const struct _olm_cipher *cipher = megolm_cipher(); size_t max_length, raw_message_length, r; Megolm *megolm; Megolm tmp_megolm; @@ -216,7 +214,7 @@ size_t olm_group_decrypt( _olm_decode_group_message( message, raw_message_length, - cipher->ops->mac_length(cipher), + megolm_cipher->ops->mac_length(megolm_cipher), &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -231,8 +229,8 @@ size_t olm_group_decrypt( return (size_t)-1; } - max_length = cipher->ops->decrypt_max_plaintext_length( - cipher, + max_length = megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, decoded_results.ciphertext_length ); if (max_plaintext_length < max_length) { @@ -258,8 +256,8 @@ size_t olm_group_decrypt( megolm_advance_to(megolm, decoded_results.message_index); /* now try checking the mac, and decrypting */ - r = cipher->ops->decrypt( - cipher, + r = megolm_cipher->ops->decrypt( + megolm_cipher, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, message, raw_message_length, decoded_results.ciphertext, decoded_results.ciphertext_length, diff --git a/src/megolm.c b/src/megolm.c index 7567894..110f939 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -22,18 +22,9 @@ #include "olm/crypto.h" #include "olm/pickle.h" -const struct _olm_cipher *megolm_cipher() { - static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; - static struct _olm_cipher *cipher; - static struct _olm_cipher_aes_sha_256 OLM_CIPHER; - if (!cipher) { - cipher = _olm_cipher_aes_sha_256_init( - &OLM_CIPHER, - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 - ); - } - return cipher; -} +static const struct _olm_cipher_aes_sha_256 MEGOLM_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("MEGOLM_KEYS"); +const struct _olm_cipher *megolm_cipher = OLM_CIPHER_BASE(&MEGOLM_CIPHER); /* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. */ diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index cf7d32c..9f36ad8 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -179,13 +179,12 @@ static size_t raw_message_length( size_t plaintext_length) { size_t ciphertext_length, mac_length; - const struct _olm_cipher *cipher = megolm_cipher(); - ciphertext_length = cipher->ops->encrypt_ciphertext_length( - cipher, plaintext_length + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, plaintext_length ); - mac_length = cipher->ops->mac_length(cipher); + mac_length = megolm_cipher->ops->mac_length(megolm_cipher); return _olm_encode_group_message_length( GROUP_SESSION_ID_LENGTH, session->ratchet.counter, @@ -210,7 +209,6 @@ size_t olm_group_encrypt( size_t rawmsglen; size_t result; uint8_t *ciphertext_ptr, *message_pos; - const struct _olm_cipher *cipher = megolm_cipher(); rawmsglen = raw_message_length(session, plaintext_length); @@ -219,8 +217,8 @@ size_t olm_group_encrypt( return (size_t)-1; } - ciphertext_length = cipher->ops->encrypt_ciphertext_length( - cipher, + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, plaintext_length ); @@ -240,8 +238,8 @@ size_t olm_group_encrypt( message_pos, &ciphertext_ptr); - result = cipher->ops->encrypt( - cipher, + result = megolm_cipher->ops->encrypt( + megolm_cipher, megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, plaintext, plaintext_length, ciphertext_ptr, ciphertext_length, -- cgit v1.2.3 From 1b15465c42a88f750a960a0e73f186245f9bba33 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 16:23:19 +0100 Subject: Separate base64ing from the rest of msg encoding Factor the actual message encoding/decoding and encrypting/decrypting out to separate functions from the top-level functions which do the base64-wrangling. This is particularly helpful in the 'outbound' code-path where the offsets required to allow room to base64-encode make the flow hard to see when it's all inline. --- include/olm/message.h | 4 ++- src/inbound_group_session.c | 64 +++++++++++++++++++++++++++++++------------ src/message.cpp | 3 +- src/outbound_group_session.c | 65 +++++++++++++++++++++++++++++--------------- 4 files changed, 94 insertions(+), 42 deletions(-) diff --git a/include/olm/message.h b/include/olm/message.h index cff15f3..e80d54c 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -53,8 +53,10 @@ size_t _olm_encode_group_message_length( * olm_encode_group_message_length() bytes long. * ciphertext_ptr: returns the address that the ciphertext * should be written to, followed by the MAC. + * + * Returns the size of the message, up to the MAC. */ -void _olm_encode_group_message( +size_t _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index b6894c1..e171205 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -163,19 +163,15 @@ size_t olm_unpickle_inbound_group_session( return pickled_length; } -size_t olm_group_decrypt_max_plaintext_length( +/** + * get the max plaintext length in an un-base64-ed message + */ +static size_t _decrypt_max_plaintext_length( OlmInboundGroupSession *session, uint8_t * message, size_t message_length ) { - size_t r; struct _OlmDecodeGroupMessageResults decoded_results; - r = _olm_decode_base64(message, message_length, message); - if (r == (size_t)-1) { - session->last_error = OLM_INVALID_BASE64; - return r; - } - _olm_decode_group_message( message, message_length, megolm_cipher->ops->mac_length(megolm_cipher), @@ -195,25 +191,38 @@ size_t olm_group_decrypt_max_plaintext_length( megolm_cipher, decoded_results.ciphertext_length); } +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + size_t raw_length; -size_t olm_group_decrypt( + raw_length = _olm_decode_base64(message, message_length, message); + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + return _decrypt_max_plaintext_length( + session, message, raw_length + ); +} + +/** + * decrypt an un-base64-ed message + */ +static size_t _decrypt( OlmInboundGroupSession *session, uint8_t * message, size_t message_length, uint8_t * plaintext, size_t max_plaintext_length ) { struct _OlmDecodeGroupMessageResults decoded_results; - size_t max_length, raw_message_length, r; + size_t max_length, r; Megolm *megolm; Megolm tmp_megolm; - raw_message_length = _olm_decode_base64(message, message_length, message); - if (raw_message_length == (size_t)-1) { - session->last_error = OLM_INVALID_BASE64; - return (size_t)-1; - } - _olm_decode_group_message( - message, raw_message_length, + message, message_length, megolm_cipher->ops->mac_length(megolm_cipher), &decoded_results); @@ -259,7 +268,7 @@ size_t olm_group_decrypt( r = megolm_cipher->ops->decrypt( megolm_cipher, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, - message, raw_message_length, + message, message_length, decoded_results.ciphertext, decoded_results.ciphertext_length, plaintext, max_plaintext_length ); @@ -272,3 +281,22 @@ size_t olm_group_decrypt( return r; } + +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + size_t raw_message_length; + + raw_message_length = _olm_decode_base64(message, message_length, message); + if (raw_message_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + return _decrypt( + session, message, raw_message_length, + plaintext, max_plaintext_length + ); +} diff --git a/src/message.cpp b/src/message.cpp index ab4300e..2e841e5 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -347,7 +347,7 @@ size_t _olm_encode_group_message_length( } -void _olm_encode_group_message( +size_t _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, @@ -364,6 +364,7 @@ void _olm_encode_group_message( std::memcpy(session_id_pos, session_id, session_id_length); pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index); pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); + return pos-output; } void _olm_decode_group_message( diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index 9f36ad8..9b2298a 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -199,51 +199,41 @@ size_t olm_group_encrypt_message_length( return _olm_encode_base64_length(message_length); } - -size_t olm_group_encrypt( - OlmOutboundGroupSession *session, - uint8_t const * plaintext, size_t plaintext_length, - uint8_t * message, size_t max_message_length +/** write an un-base64-ed message to the buffer */ +static size_t _encrypt( + OlmOutboundGroupSession *session, uint8_t const * plaintext, size_t plaintext_length, + uint8_t * buffer ) { - size_t ciphertext_length; - size_t rawmsglen; + size_t ciphertext_length, mac_length, message_length; size_t result; - uint8_t *ciphertext_ptr, *message_pos; - - rawmsglen = raw_message_length(session, plaintext_length); - - if (max_message_length < _olm_encode_base64_length(rawmsglen)) { - session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; - return (size_t)-1; - } + uint8_t *ciphertext_ptr; ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( megolm_cipher, plaintext_length ); - /* we construct the message at the end of the buffer, so that - * we have room to base64-encode it once we're done. - */ - message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen; + mac_length = megolm_cipher->ops->mac_length(megolm_cipher); /* first we build the message structure, then we encrypt * the plaintext into it. */ - _olm_encode_group_message( + message_length = _olm_encode_group_message( OLM_PROTOCOL_VERSION, session->session_id, GROUP_SESSION_ID_LENGTH, session->ratchet.counter, ciphertext_length, - message_pos, + buffer, &ciphertext_ptr); + message_length += mac_length; + result = megolm_cipher->ops->encrypt( megolm_cipher, megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, plaintext, plaintext_length, ciphertext_ptr, ciphertext_length, - message_pos, rawmsglen + buffer, message_length ); if (result == (size_t)-1) { @@ -252,6 +242,37 @@ size_t olm_group_encrypt( megolm_advance(&(session->ratchet)); + return result; +} + +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t max_message_length +) { + size_t rawmsglen; + size_t result; + uint8_t *message_pos; + + rawmsglen = raw_message_length(session, plaintext_length); + + if (max_message_length < _olm_encode_base64_length(rawmsglen)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* we construct the message at the end of the buffer, so that + * we have room to base64-encode it once we're done. + */ + message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen; + + /* write the message, and encrypt it, at message_pos */ + result = _encrypt(session, plaintext, plaintext_length, message_pos); + if (result == (size_t)-1) { + return result; + } + + /* bas64-encode it */ return _olm_encode_base64( message_pos, rawmsglen, message ); -- cgit v1.2.3 From f3c0dd76d73368a872d7acb2ffa293330f875089 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 17:03:42 +0100 Subject: megolm.c: Remove spurious arguments to rehash_part These were left over from when rehash_part did a bunch of logging. --- src/megolm.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/megolm.c b/src/megolm.c index 110f939..efefc11 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -38,8 +38,7 @@ static uint8_t HASH_KEY_SEEDS[MEGOLM_RATCHET_PARTS][HASH_KEY_SEED_LENGTH] = { static void rehash_part( uint8_t data[MEGOLM_RATCHET_PARTS][MEGOLM_RATCHET_PART_LENGTH], - int rehash_from_part, int rehash_to_part, - uint32_t old_counter, uint32_t new_counter + int rehash_from_part, int rehash_to_part ) { _olm_crypto_hmac_sha256( data[rehash_from_part], @@ -96,7 +95,7 @@ void megolm_advance(Megolm *megolm) { /* now update R(h)...R(3) based on R(h) */ for (i = MEGOLM_RATCHET_PARTS-1; i >= h; i--) { - rehash_part(megolm->data, h, i, megolm->counter-1, megolm->counter); + rehash_part(megolm->data, h, i); } } @@ -122,7 +121,7 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { * to R(j+1)...R(3). */ while (steps > 1) { - rehash_part(megolm->data, j, j, megolm->counter, next_counter); + rehash_part(megolm->data, j, j); megolm->counter = next_counter; steps --; next_counter = megolm->counter + increment; @@ -150,7 +149,7 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { } while (k >= j) { - rehash_part(megolm->data, j, k, megolm->counter, next_counter); + rehash_part(megolm->data, j, k); k--; } megolm->counter = next_counter; -- cgit v1.2.3 From ef8d24f4839352963f8d0b53919016c35f492a22 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 17:33:41 +0100 Subject: megolm.c: rewrite counter update We no longer need to keep track of intermediate values of the counter, which means we can update it much more easily. --- src/megolm.c | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/megolm.c b/src/megolm.c index efefc11..9e28f8d 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -105,26 +105,21 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { /* starting with R0, see if we need to update each part of the hash */ for (j = 0; j < (int)MEGOLM_RATCHET_PARTS; j++) { int shift = (MEGOLM_RATCHET_PARTS-j-1) * 8; - uint32_t increment = 1 << shift; - uint32_t next_counter; + uint32_t mask = (~(uint32_t)0) << shift; /* how many times to we need to rehash this part? */ int steps = (advance_to >> shift) - (megolm->counter >> shift); + if (steps == 0) { continue; } - megolm->counter = megolm->counter & ~(increment - 1); - next_counter = megolm->counter + increment; - /* for all but the last step, we can just bump R(j) without regard * to R(j+1)...R(3). */ while (steps > 1) { rehash_part(megolm->data, j, j); - megolm->counter = next_counter; steps --; - next_counter = megolm->counter + increment; } /* on the last step (except for j=3), we need to bump at least R(j+1); @@ -152,6 +147,6 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { rehash_part(megolm->data, j, k); k--; } - megolm->counter = next_counter; + megolm->counter = advance_to & mask; } } -- cgit v1.2.3 From 1f31427139acdef00f24aac13b67224fa915f9e7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 16:58:51 +0100 Subject: megolm_advance_to: Remove excessive optimisation There was some slightly overcomplex logic designed to save a couple of hash operations when R(0) and R(1) were advanced, but the extra code was hard to understand and didn't save much. --- src/megolm.c | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/src/megolm.c b/src/megolm.c index 9e28f8d..6d8af08 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -106,6 +106,7 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { for (j = 0; j < (int)MEGOLM_RATCHET_PARTS; j++) { int shift = (MEGOLM_RATCHET_PARTS-j-1) * 8; uint32_t mask = (~(uint32_t)0) << shift; + int k; /* how many times to we need to rehash this part? */ int steps = (advance_to >> shift) - (megolm->counter >> shift); @@ -122,30 +123,14 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { steps --; } - /* on the last step (except for j=3), we need to bump at least R(j+1); - * depending on the target count, we may also need to bump R(j+2) and - * R(j+3). + /* on the last step we also need to bump R(j+1)...R(3). + * + * (Theoretically, we could skip bumping R(j+2) if we're going to bump + * R(j+1) again, but the code to figure that out is a bit baroque and + * doesn't save us much). */ - int k; - switch(j) { - case 0: - if (!(advance_to & 0xFFFF00)) { k = 3; } - else if (!(advance_to & 0xFF00)) { k = 2; } - else { k = 1; } - break; - case 1: - if (!(advance_to & 0xFF00)) { k = 3; } - else { k = 2; } - break; - case 2: - case 3: - k = 3; - break; - } - - while (k >= j) { + for (k = 3; k >= j; k--) { rehash_part(megolm->data, j, k); - k--; } megolm->counter = advance_to & mask; } -- cgit v1.2.3 From 01ea3d4b9a3c6f3e0303c2d421a248715a96af99 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 May 2016 17:52:35 +0100 Subject: Fix handling of integer wraparound in megolm.c --- src/megolm.c | 8 ++++++-- tests/test_megolm.cpp | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/megolm.c b/src/megolm.c index 6d8af08..a969b36 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -108,8 +108,12 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { uint32_t mask = (~(uint32_t)0) << shift; int k; - /* how many times to we need to rehash this part? */ - int steps = (advance_to >> shift) - (megolm->counter >> shift); + /* how many times do we need to rehash this part? + * + * '& 0xff' ensures we handle integer wraparound correctly + */ + unsigned int steps = + ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { continue; diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index 871de36..bf53346 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -82,4 +82,20 @@ std::uint8_t random_bytes[] = assert_equals(expected3, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance wraparound"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x1000000); + assert_equals(0x1000000U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0); + megolm_advance_to(&mr2, 0x2000000); + assert_equals(0x2000000U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + } -- cgit v1.2.3 From 19a7fb5df5ec3445201ce5fbe475a08faf6319fc Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 25 May 2016 15:00:05 +0100 Subject: Fix an integer wrap around bug and add a couple more tests --- src/megolm.c | 6 +++++- tests/test_megolm.cpp | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/megolm.c b/src/megolm.c index a969b36..affd3cb 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -116,7 +116,11 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { - continue; + if (advance_to < megolm->counter) { + steps = 0x100; + } else { + continue; + } } /* for all but the last step, we can just bump R(j) without regard diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index bf53346..3048fa3 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -98,4 +98,37 @@ std::uint8_t random_bytes[] = assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance overflow by one"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0xffffffffUL); + megolm_advance(&mr2); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + +{ + TestCase test_case("Megolm::advance overflow"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0x1UL); + megolm_advance_to(&mr1, 0x80000000UL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0x1UL); + megolm_advance_to(&mr2, 0x0UL); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + } -- cgit v1.2.3 From fae8dacab5233c46f09e7d869afadaead2842609 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 25 May 2016 15:44:39 +0100 Subject: Add a comment explaining Mark's latest fix --- src/megolm.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/megolm.c b/src/megolm.c index affd3cb..3395449 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -116,6 +116,11 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { + /* deal with the edge case where megolm->counter is slightly larger + * than advance_to. This should only happen for R(0), and implies + * that advance_to has wrapped around and we need to advance R(0) + * 256 times. + */ if (advance_to < megolm->counter) { steps = 0x100; } else { -- cgit v1.2.3