diff options
Diffstat (limited to 'src/session.cpp')
-rw-r--r-- | src/session.cpp | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/src/session.cpp b/src/session.cpp index 654cf1f..0249e6c 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session( namespace { bool check_message_fields( - olm::PreKeyMessageReader & reader + olm::PreKeyMessageReader & reader, bool have_their_identity_key ) { bool ok = true; - ok = ok && reader.identity_key; - ok = ok && reader.identity_key_length == KEY_LENGTH; + ok = ok && (have_their_identity_key || reader.identity_key); + if (reader.identity_key) { + ok = ok && reader.identity_key_length == KEY_LENGTH; + } ok = ok && reader.message; ok = ok && reader.base_key; ok = ok && reader.base_key_length == KEY_LENGTH; @@ -120,16 +122,27 @@ bool check_message_fields( std::size_t olm::Session::new_inbound_session( olm::Account & local_account, + olm::Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ) { olm::PreKeyMessageReader reader; decode_one_time_key_message(reader, one_time_key_message, message_length); - if (!check_message_fields(reader)) { + if (!check_message_fields(reader, their_identity_key)) { last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT; return std::size_t(-1); } + if (reader.identity_key && their_identity_key) { + bool same = 0 == std::memcmp( + their_identity_key->public_key, reader.identity_key, KEY_LENGTH + ); + if (!same) { + last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID; + return std::size_t(-1); + } + } + olm::MessageReader message_reader; decode_message( message_reader, reader.message, reader.message_length, @@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session( bool olm::Session::matches_inbound_session( + olm::Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ) { olm::PreKeyMessageReader reader; decode_one_time_key_message(reader, one_time_key_message, message_length); - if (!check_message_fields(reader)) { + if (!check_message_fields(reader, their_identity_key)) { return false; } bool same = true; - same = same && 0 == std::memcmp( - reader.identity_key, alice_identity_key.public_key, KEY_LENGTH - ); + if (reader.identity_key) { + same = same && 0 == std::memcmp( + reader.identity_key, alice_identity_key.public_key, KEY_LENGTH + ); + } + if (their_identity_key) { + same = same && 0 == std::memcmp( + their_identity_key->public_key, alice_identity_key.public_key, + KEY_LENGTH + ); + } same = same && 0 == std::memcmp( reader.base_key, alice_base_key.public_key, KEY_LENGTH ); |