diff options
-rw-r--r-- | include/olm/account.hh | 46 | ||||
-rw-r--r-- | include/olm/base64.hh | 10 | ||||
-rwxr-xr-x | olm.py | 15 | ||||
-rw-r--r-- | src/account.cpp | 91 | ||||
-rw-r--r-- | src/base64.cpp | 16 | ||||
-rw-r--r-- | src/olm.cpp | 55 | ||||
-rw-r--r-- | src/session.cpp | 1 | ||||
-rw-r--r-- | tests/test_olm.cpp | 4 |
8 files changed, 127 insertions, 111 deletions
diff --git a/include/olm/account.hh b/include/olm/account.hh index 98b6b56..552f069 100644 --- a/include/olm/account.hh +++ b/include/olm/account.hh @@ -63,34 +63,36 @@ struct Account { /** Output the identity keys for this account as JSON in the following * format. * - * 14 "{\"algorithms\":" - * 30 "[\"m.olm.curve25519-aes-sha256\"" - * 15 "],\"device_id\":\"" + * 14 {"algorithms": + * 30 ["m.olm.curve25519-aes-sha256" + * 15 ],"device_id":" * ? <device identifier> - * 22 "\",\"keys\":{\"curve25519:" + * 22 ","keys":{"curve25519: * 4 <base64 characters> - * 3 "\":\"" + * 3 ":" * 43 <base64 characters> - * 11 "\",\"ed25519:" + * 11 ","ed25519: * 4 <base64 characters> - * 3 "\":\"" + * 3 ":" * 43 <base64 characters> - * 14 "\"},\"user_id\":\"" + * 14 "},"user_id":" * ? <user identifier> - * 19 "\",\"valid_after_ts\":" + * 19 ","valid_after_ts": * ? <digits> - * 18 ",\"valid_until_ts\":" + * 18 ,"valid_until_ts": * ? <digits> - * 16 ",\"signatures\":{\"" + * 16 ,"signatures":{" * ? <user identifier> - * 1 "/" + * 1 / * ? <device identifier> - * 12 "\":{\"ed25519:" + * 12 ":{"ed25519: * 4 <base64 characters> - * 3 "\":\"" + * 3 ":" * 86 <base64 characters> - * 4 "\"}}}" - */ + * 4 "}}} + * + * Returns the size of the JSON written or std::size_t(-1) on error. + * If the buffer is too small last_error will be OUTPUT_BUFFER_TOO_SMALL. */ std::size_t get_identity_json( std::uint8_t const * user_id, std::size_t user_id_length, std::uint8_t const * device_id, std::size_t device_id_length, @@ -99,6 +101,18 @@ struct Account { std::uint8_t * identity_json, std::size_t identity_json_length ); + /** Number of bytes needed to output the one time keys for this account */ + std::size_t get_one_time_keys_json_length(); + + /* + * Returns the size of the JSON written or std::size_t(-1) on error. + * If the buffer is too small last_error will be OUTPUT_BUFFER_TOO_SMALL. + */ + std::size_t get_one_time_keys_json( + std::uint8_t * one_time_json, std::size_t one_time_json_length + ); + + /** Lookup a one_time key with the given key-id */ OneTimeKey const * lookup_key( Curve25519PublicKey const & public_key ); diff --git a/include/olm/base64.hh b/include/olm/base64.hh index a68894d..018924a 100644 --- a/include/olm/base64.hh +++ b/include/olm/base64.hh @@ -21,12 +21,14 @@ namespace olm { -std::size_t encode_base64_length( +static std::size_t encode_base64_length( std::size_t input_length -); +) { + return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2; +} -void encode_base64( +std::uint8_t * encode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ); @@ -37,7 +39,7 @@ std::size_t decode_base64_length( ); -void decode_base64( +std::uint8_t const * decode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ); @@ -297,8 +297,8 @@ if __name__ == '__main__': create_account.set_defaults(func=do_create_account) keys = commands.add_parser("keys", help="List public keys for an account") - keys.add_argument("--user-id", default="A User ID") - keys.add_argument("--device-id", default="A Device ID") + keys.add_argument("--user-id", default="@user:example.com") + keys.add_argument("--device-id", default="default_device_id") keys.add_argument("--valid-after", default=0, type=int) keys.add_argument("--valid-until", default=0, type=int) keys.add_argument("account_file", help="Local account file") @@ -311,18 +311,11 @@ if __name__ == '__main__': "device_keys": account.identity_keys( args.user_id, args.device_id, args.valid_after, args.valid_until, - ) - } - ot_keys = account.one_time_keys() - result2 = { - "one_time_keys": [{ - "keyId": k[0], - "publicKey": str(k[1]), - } for k in ot_keys[1:]] + ), + "one_time_keys": account.one_time_keys(), } try: yaml.safe_dump(result1, sys.stdout, default_flow_style=False) - yaml.dump(result2, sys.stdout, default_flow_style=False) except: pass diff --git a/src/account.cpp b/src/account.cpp index 297d2f4..a171f5c 100644 --- a/src/account.cpp +++ b/src/account.cpp @@ -143,13 +143,17 @@ std::size_t olm::Account::get_identity_json_length( length += sizeof(IDENTITY_JSON_PART_0) - 1; length += device_id_length; length += sizeof(IDENTITY_JSON_PART_1) - 1; - length += 4; + length += olm::encode_base64_length(3); length += sizeof(IDENTITY_JSON_PART_2) - 1; - length += 43; + length += olm::encode_base64_length( + sizeof(identity_keys.curve25519_key.public_key) + ); length += sizeof(IDENTITY_JSON_PART_3) - 1; - length += 4; + length += olm::encode_base64_length(3); length += sizeof(IDENTITY_JSON_PART_4) - 1; - length += 43; + length += olm::encode_base64_length( + sizeof(identity_keys.ed25519_key.public_key) + ); length += sizeof(IDENTITY_JSON_PART_5) - 1; length += user_id_length; length += sizeof(IDENTITY_JSON_PART_6) - 1; @@ -161,9 +165,9 @@ std::size_t olm::Account::get_identity_json_length( length += sizeof(IDENTITY_JSON_PART_9) - 1; length += device_id_length; length += sizeof(IDENTITY_JSON_PART_A) - 1; - length += 4; + length += olm::encode_base64_length(3); length += sizeof(IDENTITY_JSON_PART_B) - 1; - length += 86; + length += olm::encode_base64_length(64); length += sizeof(IDENTITY_JSON_PART_C) - 1; return length; } @@ -176,7 +180,6 @@ std::size_t olm::Account::get_identity_json( std::uint64_t valid_until_ts, std::uint8_t * identity_json, std::size_t identity_json_length ) { - std::uint8_t * pos = identity_json; std::uint8_t signature[64]; size_t expected_length = get_identity_json_length( @@ -191,17 +194,13 @@ std::size_t olm::Account::get_identity_json( pos = write_string(pos, IDENTITY_JSON_PART_0); pos = write_string(pos, device_id, device_id_length); pos = write_string(pos, IDENTITY_JSON_PART_1); - encode_base64(identity_keys.curve25519_key.public_key, 3, pos); - pos += 4; + pos = encode_base64(identity_keys.curve25519_key.public_key, 3, pos); pos = write_string(pos, IDENTITY_JSON_PART_2); - encode_base64(identity_keys.curve25519_key.public_key, 32, pos); - pos += 43; + pos = encode_base64(identity_keys.curve25519_key.public_key, 32, pos); pos = write_string(pos, IDENTITY_JSON_PART_3); - encode_base64(identity_keys.ed25519_key.public_key, 3, pos); - pos += 4; + pos = encode_base64(identity_keys.ed25519_key.public_key, 3, pos); pos = write_string(pos, IDENTITY_JSON_PART_4); - encode_base64(identity_keys.ed25519_key.public_key, 32, pos); - pos += 43; + pos = encode_base64(identity_keys.ed25519_key.public_key, 32, pos); pos = write_string(pos, IDENTITY_JSON_PART_5); pos = write_string(pos, user_id, user_id_length); pos = write_string(pos, IDENTITY_JSON_PART_6); @@ -221,15 +220,69 @@ std::size_t olm::Account::get_identity_json( pos = write_string(pos, IDENTITY_JSON_PART_9); pos = write_string(pos, device_id, device_id_length); pos = write_string(pos, IDENTITY_JSON_PART_A); - encode_base64(identity_keys.ed25519_key.public_key, 3, pos); - pos += 4; + pos = encode_base64(identity_keys.ed25519_key.public_key, 3, pos); pos = write_string(pos, IDENTITY_JSON_PART_B); - encode_base64(signature, 64, pos); - pos += 86; + pos = encode_base64(signature, 64, pos); pos = write_string(pos, IDENTITY_JSON_PART_C); return pos - identity_json; } +namespace { +uint8_t ONE_TIME_KEY_JSON_ALG[] = "curve25519"; +} + +std::size_t olm::Account::get_one_time_keys_json_length( +) { + std::size_t length = 0; + for (auto const & key : one_time_keys) { + length += 2; /* {" */ + length += sizeof(ONE_TIME_KEY_JSON_ALG) - 1; + length += 1; /* : */ + length += olm::encode_base64_length(olm::pickle_length(key.id)); + length += 3; /* ":" */ + length += olm::encode_base64_length(sizeof(key.key.public_key)); + length += 1; /* " */ + } + if (length) { + return length + 1; /* } */ + } else { + return 2; /* {} */ + } +} + + +std::size_t olm::Account::get_one_time_keys_json( + std::uint8_t * one_time_json, std::size_t one_time_json_length +) { + std::uint8_t * pos = one_time_json; + if (one_time_json_length < get_one_time_keys_json_length()) { + last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::uint8_t sep = '{'; + for (auto const & key : one_time_keys) { + *(pos++) = sep; + *(pos++) = '\"'; + pos = write_string(pos, ONE_TIME_KEY_JSON_ALG); + *(pos++) = ':'; + std::uint8_t key_id[olm::pickle_length(key.id)]; + olm::pickle(key_id, key.id); + pos = olm::encode_base64(key_id, sizeof(key_id), pos); + *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"'; + pos = olm::encode_base64( + key.key.public_key, sizeof(key.key.public_key), pos + ); + *(pos++) = '\"'; + sep = ','; + } + if (sep != ',') { + *(pos++) = sep; + } + *(pos++) = '}'; + return pos - one_time_json; +} + + namespace olm { static std::size_t pickle_length( diff --git a/src/base64.cpp b/src/base64.cpp index a8631a1..bf8492e 100644 --- a/src/base64.cpp +++ b/src/base64.cpp @@ -45,14 +45,7 @@ static const std::uint8_t DECODE_BASE64[128] = { } // namespace -std::size_t olm::encode_base64_length( - std::size_t input_length -) { - return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2; -} - - -void olm::encode_base64( +std::uint8_t * olm::encode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ) { @@ -70,6 +63,7 @@ void olm::encode_base64( output += 4; } unsigned remainder = input + input_length - pos; + std::uint8_t * result = output; if (remainder) { unsigned value = pos[0]; if (remainder == 2) { @@ -77,13 +71,16 @@ void olm::encode_base64( value <<= 2; output[2] = ENCODE_BASE64[value & 0x3F]; value >>= 6; + result += 3; } else { value <<= 4; + result += 2; } output[1] = ENCODE_BASE64[value & 0x3F]; value >>= 6; output[0] = ENCODE_BASE64[value]; } + return result; } @@ -98,7 +95,7 @@ std::size_t olm::decode_base64_length( } -void olm::decode_base64( +std::uint8_t const * olm::decode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ) { @@ -129,4 +126,5 @@ void olm::decode_base64( } output[0] = value; } + return input + input_length; } diff --git a/src/olm.cpp b/src/olm.cpp index 8106917..ede9c26 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -320,38 +320,6 @@ size_t olm_create_account( return from_c(account)->new_account(from_c(random), random_length); } -namespace { - -static const std::size_t OUTPUT_KEY_LENGTH = 2 + 10 + 2 + - olm::encode_base64_length(32) + 3; - -void output_one_time_key( - olm::OneTimeKey const & key, - std::uint8_t sep, - std::uint8_t * output -) { - output[0] = sep; - output[1] = '['; - std::memset(output + 2, ' ', 10); - uint32_t value = key.id; - uint8_t * number = output + 11; - *number = '0' + value % 10; - value /= 10; - while (value) { - *(--number) = '0' + value % 10; - value /= 10; - } - output[12] = ','; - output[13] = '"'; - olm::encode_base64(key.key.public_key, 32, output + 14); - output[OUTPUT_KEY_LENGTH - 3] = '"'; - output[OUTPUT_KEY_LENGTH - 2] = ']'; - output[OUTPUT_KEY_LENGTH - 1] = '\n'; -} - -} // namespace - - size_t olm_account_identity_keys_length( OlmAccount * account, size_t user_id_length, @@ -388,30 +356,17 @@ size_t olm_account_identity_keys( size_t olm_account_one_time_keys_length( OlmAccount * account ) { - size_t count = from_c(account)->one_time_keys.size(); - return OUTPUT_KEY_LENGTH * count + 1; + return from_c(account)->get_one_time_keys_json_length(); } size_t olm_account_one_time_keys( OlmAccount * account, - void * identity_keys, size_t identity_key_length + void * one_time_keys_json, size_t one_time_key_json_length ) { - std::size_t length = olm_account_one_time_keys_length(account); - if (identity_key_length < length) { - from_c(account)->last_error = - olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL; - return size_t(-1); - } - std::uint8_t * output = from_c(identity_keys); - std::uint8_t sep = '['; - for (auto const & key : from_c(account)->one_time_keys) { - output_one_time_key(key, sep, output); - output += OUTPUT_KEY_LENGTH; - sep = ','; - } - output[0] = ']'; - return length; + return from_c(account)->get_one_time_keys_json( + from_c(one_time_keys_json), one_time_key_json_length + ); } diff --git a/src/session.cpp b/src/session.cpp index a56725d..f3b7637 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -157,6 +157,7 @@ std::size_t olm::Session::new_inbound_session( last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID; return std::size_t(-1); } + bob_one_time_key_id = our_one_time_key->id; std::uint8_t shared_secret[96]; diff --git a/tests/test_olm.cpp b/tests/test_olm.cpp index 7c3ec87..bedabf3 100644 --- a/tests/test_olm.cpp +++ b/tests/test_olm.cpp @@ -89,7 +89,7 @@ mock_random_a(a_rand, sizeof(a_rand)); assert_not_equals(std::size_t(-1), ::olm_create_outbound_session( a_session, a_account, b_id_keys + 88, 43, - b_ot_keys + 74, 43, + b_ot_keys + 22, 43, a_rand, sizeof(a_rand) )); @@ -193,7 +193,7 @@ mock_random_a(a_rand, sizeof(a_rand)); assert_not_equals(std::size_t(-1), ::olm_create_outbound_session( a_session, a_account, b_id_keys + 88, 43, - b_ot_keys + 74, 43, + b_ot_keys + 22, 43, a_rand, sizeof(a_rand) )); |