From 89d9b972a6d629648d18f4227a08596c65c3894d Mon Sep 17 00:00:00 2001
From: Mark Haines <mark.haines@matrix.org>
Date: Thu, 16 Jul 2015 10:45:10 +0100
Subject: Add versions of olm_session_create_inbound and
 olm_session_matches_inbound which take the curve25519 identity key of the
 remote device we think the message is from as an additional argument

---
 src/olm.cpp     | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 src/session.cpp | 38 +++++++++++++++++++++++++++--------
 2 files changed, 89 insertions(+), 10 deletions(-)

(limited to 'src')

diff --git a/src/olm.cpp b/src/olm.cpp
index b121ec7..17461fe 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -518,7 +518,36 @@ size_t olm_create_inbound_session(
         return std::size_t(-1);
     }
     return from_c(session)->new_inbound_session(
-        *from_c(account), from_c(one_time_key_message), raw_length
+        *from_c(account), nullptr, from_c(one_time_key_message), raw_length
+    );
+}
+
+
+size_t olm_create_inbound_session_from(
+    OlmSession * session,
+    OlmAccount * account,
+    void const * their_identity_key, size_t their_identity_key_length,
+    void * one_time_key_message, size_t message_length
+) {
+    if (olm::decode_base64_length(their_identity_key_length) != 32) {
+        from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
+        return std::size_t(-1);
+    }
+    olm::Curve25519PublicKey identity_key;
+    olm::decode_base64(
+        from_c(their_identity_key), their_identity_key_length,
+        identity_key.public_key
+    );
+
+    std::size_t raw_length = b64_input(
+        from_c(one_time_key_message), message_length, from_c(session)->last_error
+    );
+    if (raw_length == std::size_t(-1)) {
+        return std::size_t(-1);
+    }
+    return from_c(session)->new_inbound_session(
+        *from_c(account), &identity_key,
+        from_c(one_time_key_message), raw_length
     );
 }
 
@@ -534,7 +563,35 @@ size_t olm_matches_inbound_session(
         return std::size_t(-1);
     }
     bool matches = from_c(session)->matches_inbound_session(
-        from_c(one_time_key_message), raw_length
+        nullptr, from_c(one_time_key_message), raw_length
+    );
+    return matches ? 1 : 0;
+}
+
+
+size_t olm_matches_inbound_session_from(
+    OlmSession * session,
+    void const * their_identity_key, size_t their_identity_key_length,
+    void * one_time_key_message, size_t message_length
+) {
+    if (olm::decode_base64_length(their_identity_key_length) != 32) {
+        from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
+        return std::size_t(-1);
+    }
+    olm::Curve25519PublicKey identity_key;
+    olm::decode_base64(
+        from_c(their_identity_key), their_identity_key_length,
+        identity_key.public_key
+    );
+
+    std::size_t raw_length = b64_input(
+        from_c(one_time_key_message), message_length, from_c(session)->last_error
+    );
+    if (raw_length == std::size_t(-1)) {
+        return std::size_t(-1);
+    }
+    bool matches = from_c(session)->matches_inbound_session(
+        &identity_key, from_c(one_time_key_message), raw_length
     );
     return matches ? 1 : 0;
 }
diff --git a/src/session.cpp b/src/session.cpp
index 654cf1f..0249e6c 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session(
 namespace {
 
 bool check_message_fields(
-    olm::PreKeyMessageReader & reader
+    olm::PreKeyMessageReader & reader, bool have_their_identity_key
 ) {
     bool ok = true;
-    ok = ok && reader.identity_key;
-    ok = ok && reader.identity_key_length == KEY_LENGTH;
+    ok = ok && (have_their_identity_key || reader.identity_key);
+    if (reader.identity_key) {
+        ok = ok && reader.identity_key_length == KEY_LENGTH;
+    }
     ok = ok && reader.message;
     ok = ok && reader.base_key;
     ok = ok && reader.base_key_length == KEY_LENGTH;
@@ -120,16 +122,27 @@ bool check_message_fields(
 
 std::size_t olm::Session::new_inbound_session(
     olm::Account & local_account,
+    olm::Curve25519PublicKey const * their_identity_key,
     std::uint8_t const * one_time_key_message, std::size_t message_length
 ) {
     olm::PreKeyMessageReader reader;
     decode_one_time_key_message(reader, one_time_key_message, message_length);
 
-    if (!check_message_fields(reader)) {
+    if (!check_message_fields(reader, their_identity_key)) {
         last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
         return std::size_t(-1);
     }
 
+    if (reader.identity_key && their_identity_key) {
+        bool same = 0 == std::memcmp(
+            their_identity_key->public_key, reader.identity_key, KEY_LENGTH
+        );
+        if (!same) {
+            last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
+            return std::size_t(-1);
+        }
+    }
+
     olm::MessageReader message_reader;
     decode_message(
         message_reader, reader.message, reader.message_length,
@@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session(
 
 
 bool olm::Session::matches_inbound_session(
+    olm::Curve25519PublicKey const * their_identity_key,
     std::uint8_t const * one_time_key_message, std::size_t message_length
 ) {
     olm::PreKeyMessageReader reader;
     decode_one_time_key_message(reader, one_time_key_message, message_length);
 
-    if (!check_message_fields(reader)) {
+    if (!check_message_fields(reader, their_identity_key)) {
         return false;
     }
 
     bool same = true;
-    same = same && 0 == std::memcmp(
-        reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
-    );
+    if (reader.identity_key) {
+        same = same && 0 == std::memcmp(
+            reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
+        );
+    }
+    if (their_identity_key) {
+        same = same && 0 == std::memcmp(
+            their_identity_key->public_key, alice_identity_key.public_key,
+            KEY_LENGTH
+        );
+    }
     same = same && 0 == std::memcmp(
         reader.base_key, alice_base_key.public_key, KEY_LENGTH
     );
-- 
cgit v1.2.3-70-g09d2