diff options
-rw-r--r-- | include/olm/ratchet.hh | 10 | ||||
-rw-r--r-- | src/ratchet.cpp | 36 | ||||
-rw-r--r-- | tests/test_olm_decrypt.cpp | 41 |
3 files changed, 37 insertions, 50 deletions
diff --git a/include/olm/ratchet.hh b/include/olm/ratchet.hh index acb5608..aaa3579 100644 --- a/include/olm/ratchet.hh +++ b/include/olm/ratchet.hh @@ -81,16 +81,6 @@ struct Ratchet { /** The last error that happened encrypting or decrypting a message. */ OlmErrorCode last_error; - /** - * A count of the number of times the root key has been advanced; this is - * maintained purely for diagnostics. - * - * If sender_chain is empty, this will be the index of the current receiver - * chain (odd for Alice, even for Bob); otherwise, the index of the current - * sender chain (even for Alice, odd for Bob). - */ - std::uint32_t chain_index; - /** The root key is used to generate chain keys from the ephemeral keys. * A new root_key derived each time a new chain is started. */ SharedKey root_key; diff --git a/src/ratchet.cpp b/src/ratchet.cpp index abcc8a1..dd1d42c 100644 --- a/src/ratchet.cpp +++ b/src/ratchet.cpp @@ -66,7 +66,6 @@ static void create_chain_key( static void advance_chain_key( - std::uint32_t chain_index, olm::ChainKey const & chain_key, olm::ChainKey & new_chain_key ) { @@ -80,7 +79,6 @@ static void advance_chain_key( static void create_message_keys( - std::uint32_t chain_index, olm::ChainKey const & chain_key, olm::KdfInfo const & info, olm::MessageKey & message_key) { @@ -111,7 +109,6 @@ static std::size_t verify_mac_and_decrypt( static std::size_t verify_mac_and_decrypt_for_existing_chain( olm::Ratchet const & session, - std::uint32_t chain_index, olm::ChainKey const & chain, olm::MessageReader const & reader, std::uint8_t * plaintext, std::size_t max_plaintext_length @@ -128,11 +125,11 @@ static std::size_t verify_mac_and_decrypt_for_existing_chain( olm::ChainKey new_chain = chain; while (new_chain.index < reader.counter) { - advance_chain_key(chain_index, new_chain, new_chain); + advance_chain_key(new_chain, new_chain); } olm::MessageKey message_key; - create_message_keys(chain_index, new_chain, session.kdf_info, message_key); + create_message_keys(new_chain, session.kdf_info, message_key); std::size_t result = verify_mac_and_decrypt( session.ratchet_cipher, message_key, reader, @@ -164,14 +161,13 @@ static std::size_t verify_mac_and_decrypt_for_new_chain( } olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key); - std::uint32_t chain_index = session.chain_index + 1; create_chain_key( session.root_key, session.sender_chain[0].ratchet_key, new_chain.ratchet_key, session.kdf_info, new_root_key, new_chain.chain_key ); std::size_t result = verify_mac_and_decrypt_for_existing_chain( - session, chain_index, new_chain.chain_key, reader, + session, new_chain.chain_key, reader, plaintext, max_plaintext_length ); olm::unset(new_root_key); @@ -208,7 +204,6 @@ void olm::Ratchet::initialise_as_bob( pos = olm::load_array(root_key, pos); pos = olm::load_array(receiver_chains[0].chain_key.key, pos); receiver_chains[0].ratchet_key = their_ratchet_key; - chain_index = 0; olm::unset(derived_secrets); } @@ -230,7 +225,6 @@ void olm::Ratchet::initialise_as_alice( pos = olm::load_array(root_key, pos); pos = olm::load_array(sender_chain[0].chain_key.key, pos); sender_chain[0].ratchet_key = our_ratchet_key; - chain_index = 0; olm::unset(derived_secrets); } @@ -369,7 +363,6 @@ std::size_t olm::pickle_length( length += olm::pickle_length(value.sender_chain); length += olm::pickle_length(value.receiver_chains); length += olm::pickle_length(value.skipped_message_keys); - length += olm::pickle_length(value.chain_index); return length; } @@ -381,7 +374,6 @@ std::uint8_t * olm::pickle( pos = pickle(pos, value.sender_chain); pos = pickle(pos, value.receiver_chains); pos = pickle(pos, value.skipped_message_keys); - pos = pickle(pos, value.chain_index); return pos; } @@ -394,7 +386,6 @@ std::uint8_t const * olm::unpickle( pos = unpickle(pos, end, value.sender_chain); pos = unpickle(pos, end, value.receiver_chains); pos = unpickle(pos, end, value.skipped_message_keys); - pos = unpickle(pos, end, value.chain_index); return pos; } @@ -447,12 +438,11 @@ std::size_t olm::Ratchet::encrypt( kdf_info, root_key, sender_chain[0].chain_key ); - chain_index++; } MessageKey keys; - create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys); - advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key); + create_message_keys(sender_chain[0].chain_key, kdf_info, keys); + advance_chain_key(sender_chain[0].chain_key, sender_chain[0].chain_key); std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length( ratchet_cipher, @@ -538,12 +528,6 @@ std::size_t olm::Ratchet::decrypt( } ReceiverChain * chain = nullptr; - auto receiver_chain_index = chain_index; - if (!sender_chain.empty()) { - // we've already advanced to the next (sender) chain; decrement to - // get back to the receiver chains - receiver_chain_index --; - } for (olm::ReceiverChain & receiver_chain : receiver_chains) { if (0 == std::memcmp( @@ -553,7 +537,6 @@ std::size_t olm::Ratchet::decrypt( chain = &receiver_chain; break; } - receiver_chain_index -= 2; } std::size_t result = std::size_t(-1); @@ -590,7 +573,7 @@ std::size_t olm::Ratchet::decrypt( } } else { result = verify_mac_and_decrypt_for_existing_chain( - *this, receiver_chain_index, chain->chain_key, + *this, chain->chain_key, reader, plaintext, max_plaintext_length ); } @@ -618,17 +601,16 @@ std::size_t olm::Ratchet::decrypt( olm::unset(sender_chain[0]); sender_chain.erase(sender_chain.begin()); - receiver_chain_index = ++chain_index; } while (chain->chain_key.index < reader.counter) { olm::SkippedMessageKey & key = *skipped_message_keys.insert(); - create_message_keys(receiver_chain_index, chain->chain_key, kdf_info, key.message_key); + create_message_keys(chain->chain_key, kdf_info, key.message_key); key.ratchet_key = chain->ratchet_key; - advance_chain_key(receiver_chain_index, chain->chain_key, chain->chain_key); + advance_chain_key(chain->chain_key, chain->chain_key); } - advance_chain_key(receiver_chain_index, chain->chain_key, chain->chain_key); + advance_chain_key(chain->chain_key, chain->chain_key); return result; } diff --git a/tests/test_olm_decrypt.cpp b/tests/test_olm_decrypt.cpp index 95cb18e..4a1fb97 100644 --- a/tests/test_olm_decrypt.cpp +++ b/tests/test_olm_decrypt.cpp @@ -1,11 +1,16 @@ #include "olm/olm.h" #include "unittest.hh" -const char * test_cases[] = { - "41776f", - "7fff6f0101346d671201", - "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d", - "e9e9c9c1e9e9c9e9c9c1e9e9c9c1", +struct test_case { + const char *msghex; + const char *expected_error; +}; + +const test_case test_cases[] = { + { "41776f", "BAD_MESSAGE_FORMAT" }, + { "7fff6f0101346d671201", "BAD_MESSAGE_FORMAT" }, + { "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d", "BAD_MESSAGE_FORMAT" }, + { "e9e9c9c1e9e9c9e9c9c1e9e9c9c1", "BAD_MESSAGE_FORMAT" }, }; @@ -31,29 +36,39 @@ void decode_hex( } } -void decrypt_case(int message_type, const char * test_case) { +void decrypt_case(int message_type, const test_case * test_case) { std::uint8_t session_memory[olm_session_size()]; ::OlmSession * session = ::olm_session(session_memory); std::uint8_t pickled[strlen(session_data)]; ::memcpy(pickled, session_data, sizeof(pickled)); - ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled)); + assert_not_equals( + ::olm_error(), + ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled)) + ); - std::size_t message_length = strlen(test_case) / 2; + std::size_t message_length = strlen(test_case->msghex) / 2; std::uint8_t * message = (std::uint8_t *) ::malloc(message_length); - decode_hex(test_case, message, message_length); + decode_hex(test_case->msghex, message, message_length); size_t max_length = olm_decrypt_max_plaintext_length( session, message_type, message, message_length ); - if (max_length == std::size_t(-1)) { + if (test_case->expected_error) { + assert_equals(::olm_error(), max_length); + assert_equals( + std::string(test_case->expected_error), + std::string(::olm_session_last_error(session)) + ); free(message); return; } + assert_not_equals(::olm_error(), max_length); + uint8_t plaintext[max_length]; - decode_hex(test_case, message, message_length); + decode_hex(test_case->msghex, message, message_length); olm_decrypt( session, message_type, message, message_length, @@ -67,8 +82,8 @@ int main() { { TestCase my_test("Olm decrypt test"); -for (unsigned int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) { - decrypt_case(0, test_cases[i]); +for (unsigned int i = 0; i < sizeof(test_cases)/ sizeof(test_cases[0]); ++i) { + decrypt_case(0, &test_cases[i]); } } |