From a073d12d8367d27db97751d46b766e8480fd39e4 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Wed, 18 May 2016 18:03:59 +0100
Subject: Support for pickling inbound group sessions

---
 src/inbound_group_session.c  | 76 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_group_session.cpp | 33 ++++++++++++++++++-
 2 files changed, 108 insertions(+), 1 deletion(-)

diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c
index 4796414..34908a9 100644
--- a/src/inbound_group_session.c
+++ b/src/inbound_group_session.c
@@ -22,8 +22,12 @@
 #include "olm/error.h"
 #include "olm/megolm.h"
 #include "olm/message.h"
+#include "olm/pickle.h"
+#include "olm/pickle_encoding.h"
+
 
 #define OLM_PROTOCOL_VERSION     3
+#define PICKLE_VERSION           1
 
 struct OlmInboundGroupSession {
     /** our earliest known ratchet value */
@@ -86,6 +90,78 @@ size_t olm_init_inbound_group_session(
     return 0;
 }
 
+static size_t raw_pickle_length(
+    const OlmInboundGroupSession *session
+) {
+    size_t length = 0;
+    length += _olm_pickle_uint32_length(PICKLE_VERSION);
+    length += megolm_pickle_length(&session->initial_ratchet);
+    length += megolm_pickle_length(&session->latest_ratchet);
+    return length;
+}
+
+size_t olm_pickle_inbound_group_session_length(
+    const OlmInboundGroupSession *session
+) {
+    return _olm_enc_output_length(raw_pickle_length(session));
+}
+
+size_t olm_pickle_inbound_group_session(
+    OlmInboundGroupSession *session,
+    void const * key, size_t key_length,
+    void * pickled, size_t pickled_length
+) {
+    size_t raw_length = raw_pickle_length(session);
+    uint8_t *pos;
+
+    if (pickled_length < _olm_enc_output_length(raw_length)) {
+        session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return (size_t)-1;
+    }
+
+    pos = _olm_enc_output_pos(pickled, raw_length);
+    pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
+    pos = megolm_pickle(&session->initial_ratchet, pos);
+    pos = megolm_pickle(&session->latest_ratchet, pos);
+
+    return _olm_enc_output(key, key_length, pickled, raw_length);
+}
+
+size_t olm_unpickle_inbound_group_session(
+    OlmInboundGroupSession *session,
+    void const * key, size_t key_length,
+    void * pickled, size_t pickled_length
+) {
+    const uint8_t *pos;
+    const uint8_t *end;
+    uint32_t pickle_version;
+
+    size_t raw_length = _olm_enc_input(
+        key, key_length, pickled, pickled_length, &(session->last_error)
+    );
+    if (raw_length == (size_t)-1) {
+        return raw_length;
+    }
+
+    pos = pickled;
+    end = pos + raw_length;
+    pos = _olm_unpickle_uint32(pos, end, &pickle_version);
+    if (pickle_version != PICKLE_VERSION) {
+        session->last_error = OLM_UNKNOWN_PICKLE_VERSION;
+        return (size_t)-1;
+    }
+    pos = megolm_unpickle(&session->initial_ratchet, pos, end);
+    pos = megolm_unpickle(&session->latest_ratchet, pos, end);
+
+    if (end != pos) {
+        /* We had the wrong number of bytes in the input. */
+        session->last_error = OLM_CORRUPTED_PICKLE;
+        return (size_t)-1;
+    }
+
+    return pickled_length;
+}
+
 size_t olm_group_decrypt_max_plaintext_length(
     OlmInboundGroupSession *session,
     uint8_t * message, size_t message_length
diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp
index 5bbdc9d..4a82154 100644
--- a/tests/test_group_session.cpp
+++ b/tests/test_group_session.cpp
@@ -20,7 +20,7 @@
 int main() {
 
 {
-    TestCase test_case("Pickle outbound group");
+    TestCase test_case("Pickle outbound group session");
 
     size_t size = olm_outbound_group_session_size();
     uint8_t memory[size];
@@ -50,6 +50,37 @@ int main() {
 }
 
 
+{
+    TestCase test_case("Pickle inbound group session");
+
+    size_t size = olm_inbound_group_session_size();
+    uint8_t memory[size];
+    OlmInboundGroupSession *session = olm_inbound_group_session(memory);
+
+    size_t pickle_length = olm_pickle_inbound_group_session_length(session);
+    uint8_t pickle1[pickle_length];
+    olm_pickle_inbound_group_session(session,
+                                     "secret_key", 10,
+                                     pickle1, pickle_length);
+    uint8_t pickle2[pickle_length];
+    memcpy(pickle2, pickle1, pickle_length);
+
+    uint8_t buffer2[size];
+    OlmInboundGroupSession *session2 = olm_inbound_group_session(buffer2);
+    size_t res = olm_unpickle_inbound_group_session(session2,
+                                                    "secret_key", 10,
+                                                    pickle2, pickle_length);
+    assert_not_equals((size_t)-1, res);
+    assert_equals(pickle_length,
+                  olm_pickle_inbound_group_session_length(session2));
+    olm_pickle_inbound_group_session(session2,
+                                      "secret_key", 10,
+                                      pickle2, pickle_length);
+
+    assert_equals(pickle1, pickle2, pickle_length);
+}
+
+
 {
     TestCase test_case("Group message send/receive");
 
-- 
cgit v1.2.3-70-g09d2