aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-07-16 10:45:10 +0100
committerMark Haines <mark.haines@matrix.org>2015-07-16 10:45:10 +0100
commit89d9b972a6d629648d18f4227a08596c65c3894d (patch)
tree8859231d4d4d73bb56d00c055bac82f05ad34e48
parent7523b700cf5c465a484519aabe0b428c54cb91a0 (diff)
Add versions of olm_session_create_inbound and olm_session_matches_inbound which take the curve25519 identity key of the remote device we think the message is from as an additional argument
-rw-r--r--include/olm/olm.hh29
-rw-r--r--include/olm/session.hh2
-rw-r--r--src/olm.cpp61
-rw-r--r--src/session.cpp38
4 files changed, 120 insertions, 10 deletions
diff --git a/include/olm/olm.hh b/include/olm/olm.hh
index 64454e6..f08fb9f 100644
--- a/include/olm/olm.hh
+++ b/include/olm/olm.hh
@@ -242,6 +242,21 @@ size_t olm_create_inbound_session(
void * one_time_key_message, size_t message_length
);
+/** Create a new in-bound session for sending/receiving messages from an
+ * incoming PRE_KEY message. Returns olm_error() on failure. If the base64
+ * couldn't be decoded then olm_session_last_error will be "INVALID_BASE64".
+ * If the message was for an unsupported protocol version then
+ * olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message
+ * couldn't be decoded then then olm_session_last_error() will be
+ * "BAD_MESSAGE_FORMAT". If the message refers to an unknown one time
+ * key then olm_session_last_error() will be "BAD_MESSAGE_KEY_ID". */
+size_t olm_create_inbound_session_from(
+ OlmSession * session,
+ OlmAccount * account,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+);
+
/** Checks if the PRE_KEY message is for this in-bound session. This can happen
* if multiple messages are sent to this account before this account sends a
* message in reply. Returns olm_error() on failure. If the base64
@@ -255,6 +270,20 @@ size_t olm_matches_inbound_session(
void * one_time_key_message, size_t message_length
);
+/** Checks if the PRE_KEY message is for this in-bound session. This can happen
+ * if multiple messages are sent to this account before this account sends a
+ * message in reply. Returns olm_error() on failure. If the base64
+ * couldn't be decoded then olm_session_last_error will be "INVALID_BASE64".
+ * If the message was for an unsupported protocol version then
+ * olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message
+ * couldn't be decoded then then olm_session_last_error() will be
+ * "BAD_MESSAGE_FORMAT". */
+size_t olm_matches_inbound_session_from(
+ OlmSession * session,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+);
+
/** Removes the one time keys that the session used from the account. Returns
* olm_error() on failure. If the account doesn't have any matching one time
* keys then olm_account_last_error() will be "BAD_MESSAGE_KEY_ID". */
diff --git a/include/olm/session.hh b/include/olm/session.hh
index 125df68..b70ce6a 100644
--- a/include/olm/session.hh
+++ b/include/olm/session.hh
@@ -50,10 +50,12 @@ struct Session {
std::size_t new_inbound_session(
Account & local_account,
+ Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
);
bool matches_inbound_session(
+ Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
);
diff --git a/src/olm.cpp b/src/olm.cpp
index b121ec7..17461fe 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -518,7 +518,36 @@ size_t olm_create_inbound_session(
return std::size_t(-1);
}
return from_c(session)->new_inbound_session(
- *from_c(account), from_c(one_time_key_message), raw_length
+ *from_c(account), nullptr, from_c(one_time_key_message), raw_length
+ );
+}
+
+
+size_t olm_create_inbound_session_from(
+ OlmSession * session,
+ OlmAccount * account,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+) {
+ if (olm::decode_base64_length(their_identity_key_length) != 32) {
+ from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ olm::Curve25519PublicKey identity_key;
+ olm::decode_base64(
+ from_c(their_identity_key), their_identity_key_length,
+ identity_key.public_key
+ );
+
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(session)->new_inbound_session(
+ *from_c(account), &identity_key,
+ from_c(one_time_key_message), raw_length
);
}
@@ -534,7 +563,35 @@ size_t olm_matches_inbound_session(
return std::size_t(-1);
}
bool matches = from_c(session)->matches_inbound_session(
- from_c(one_time_key_message), raw_length
+ nullptr, from_c(one_time_key_message), raw_length
+ );
+ return matches ? 1 : 0;
+}
+
+
+size_t olm_matches_inbound_session_from(
+ OlmSession * session,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+) {
+ if (olm::decode_base64_length(their_identity_key_length) != 32) {
+ from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ olm::Curve25519PublicKey identity_key;
+ olm::decode_base64(
+ from_c(their_identity_key), their_identity_key_length,
+ identity_key.public_key
+ );
+
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ bool matches = from_c(session)->matches_inbound_session(
+ &identity_key, from_c(one_time_key_message), raw_length
);
return matches ? 1 : 0;
}
diff --git a/src/session.cpp b/src/session.cpp
index 654cf1f..0249e6c 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session(
namespace {
bool check_message_fields(
- olm::PreKeyMessageReader & reader
+ olm::PreKeyMessageReader & reader, bool have_their_identity_key
) {
bool ok = true;
- ok = ok && reader.identity_key;
- ok = ok && reader.identity_key_length == KEY_LENGTH;
+ ok = ok && (have_their_identity_key || reader.identity_key);
+ if (reader.identity_key) {
+ ok = ok && reader.identity_key_length == KEY_LENGTH;
+ }
ok = ok && reader.message;
ok = ok && reader.base_key;
ok = ok && reader.base_key_length == KEY_LENGTH;
@@ -120,16 +122,27 @@ bool check_message_fields(
std::size_t olm::Session::new_inbound_session(
olm::Account & local_account,
+ olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
- if (!check_message_fields(reader)) {
+ if (!check_message_fields(reader, their_identity_key)) {
last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
+ if (reader.identity_key && their_identity_key) {
+ bool same = 0 == std::memcmp(
+ their_identity_key->public_key, reader.identity_key, KEY_LENGTH
+ );
+ if (!same) {
+ last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
+ return std::size_t(-1);
+ }
+ }
+
olm::MessageReader message_reader;
decode_message(
message_reader, reader.message, reader.message_length,
@@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session(
bool olm::Session::matches_inbound_session(
+ olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
- if (!check_message_fields(reader)) {
+ if (!check_message_fields(reader, their_identity_key)) {
return false;
}
bool same = true;
- same = same && 0 == std::memcmp(
- reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
- );
+ if (reader.identity_key) {
+ same = same && 0 == std::memcmp(
+ reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
+ );
+ }
+ if (their_identity_key) {
+ same = same && 0 == std::memcmp(
+ their_identity_key->public_key, alice_identity_key.public_key,
+ KEY_LENGTH
+ );
+ }
same = same && 0 == std::memcmp(
reader.base_key, alice_base_key.public_key, KEY_LENGTH
);