aboutsummaryrefslogtreecommitdiff
path: root/src/session.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/session.cpp')
-rw-r--r--src/session.cpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/session.cpp b/src/session.cpp
index 654cf1f..0249e6c 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session(
namespace {
bool check_message_fields(
- olm::PreKeyMessageReader & reader
+ olm::PreKeyMessageReader & reader, bool have_their_identity_key
) {
bool ok = true;
- ok = ok && reader.identity_key;
- ok = ok && reader.identity_key_length == KEY_LENGTH;
+ ok = ok && (have_their_identity_key || reader.identity_key);
+ if (reader.identity_key) {
+ ok = ok && reader.identity_key_length == KEY_LENGTH;
+ }
ok = ok && reader.message;
ok = ok && reader.base_key;
ok = ok && reader.base_key_length == KEY_LENGTH;
@@ -120,16 +122,27 @@ bool check_message_fields(
std::size_t olm::Session::new_inbound_session(
olm::Account & local_account,
+ olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
- if (!check_message_fields(reader)) {
+ if (!check_message_fields(reader, their_identity_key)) {
last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
+ if (reader.identity_key && their_identity_key) {
+ bool same = 0 == std::memcmp(
+ their_identity_key->public_key, reader.identity_key, KEY_LENGTH
+ );
+ if (!same) {
+ last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
+ return std::size_t(-1);
+ }
+ }
+
olm::MessageReader message_reader;
decode_message(
message_reader, reader.message, reader.message_length,
@@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session(
bool olm::Session::matches_inbound_session(
+ olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
- if (!check_message_fields(reader)) {
+ if (!check_message_fields(reader, their_identity_key)) {
return false;
}
bool same = true;
- same = same && 0 == std::memcmp(
- reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
- );
+ if (reader.identity_key) {
+ same = same && 0 == std::memcmp(
+ reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
+ );
+ }
+ if (their_identity_key) {
+ same = same && 0 == std::memcmp(
+ their_identity_key->public_key, alice_identity_key.public_key,
+ KEY_LENGTH
+ );
+ }
same = same && 0 == std::memcmp(
reader.base_key, alice_base_key.public_key, KEY_LENGTH
);