diff options
-rw-r--r-- | include/odhtdb/sql/Sql.hpp | 11 | ||||
-rw-r--r-- | src/DatabaseStorage.cpp | 89 | ||||
-rw-r--r-- | src/sql/Sql.cpp | 20 | ||||
-rw-r--r-- | src/sql/SqlExec.cpp | 4 | ||||
-rw-r--r-- | src/sql/SqlQuery.cpp | 5 |
5 files changed, 67 insertions, 62 deletions
diff --git a/include/odhtdb/sql/Sql.hpp b/include/odhtdb/sql/Sql.hpp index 6c78360..cb740a4 100644 --- a/include/odhtdb/sql/Sql.hpp +++ b/include/odhtdb/sql/Sql.hpp @@ -34,4 +34,15 @@ namespace odhtdb }; const Type type; }; + + class SqlTransaction + { + public: + SqlTransaction(sqlite3 *db); + ~SqlTransaction(); + + void commit(); + private: + sqlite3 *db; + }; } diff --git a/src/DatabaseStorage.cpp b/src/DatabaseStorage.cpp index 8bef77e..129ba13 100644 --- a/src/DatabaseStorage.cpp +++ b/src/DatabaseStorage.cpp @@ -344,7 +344,7 @@ namespace odhtdb assert(deserializer.empty()); } - static void sqlite_step_rollback_on_failure(sqlite3 *db, sqlite3_stmt *stmt, const char *description) + static void sqlite_step_throw_on_failure(sqlite3 *db, sqlite3_stmt *stmt, const char *description) { int rc = sqlite3_step(stmt); if(rc != SQLITE_DONE) @@ -352,7 +352,6 @@ namespace odhtdb string errMsg = description; errMsg += " failed with error: "; errMsg += sqlite3_errmsg(db); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); if(rc == SQLITE_CONSTRAINT) throw DatabaseStorageAlreadyExists(errMsg); else @@ -366,7 +365,6 @@ namespace odhtdb { string errMsg = "Failed to bind param, error code: "; errMsg += to_string(sqliteBindResult); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } } @@ -385,7 +383,6 @@ namespace odhtdb { string errMsg = "select Node id failed with error: "; errMsg += sqlite3_errmsg(sqliteDb); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } @@ -406,7 +403,6 @@ namespace odhtdb { string errMsg = "select NodeAddData id failed with error: "; errMsg += sqlite3_errmsg(sqliteDb); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } @@ -427,7 +423,7 @@ namespace odhtdb void DatabaseStorage::createStorage(const Hash &hash, const Signature::PublicKey &adminPublicKey, const DataView &adminGroupId, u64 timestamp, const void *data, usize size) { - sqlite3_exec(sqliteDb, "BEGIN", 0, 0, 0); + SqlTransaction transaction(sqliteDb); { sqlite3_reset(insertNodeStmt); sqlite3_clear_bindings(insertNodeStmt); @@ -445,7 +441,7 @@ namespace odhtdb rc = sqlite3_bind_blob(insertNodeStmt, 4, adminGroupId.data, GROUP_ID_LENGTH, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeStmt, "insert data into Node"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeStmt, "insert data into Node"); addGroup(hash, adminGroupId, ADMIN_PERMISSION); addUser(hash, adminPublicKey, adminGroupId); } @@ -460,9 +456,9 @@ namespace odhtdb rc = sqlite3_bind_blob(insertNodeRawStmt, 2, data, size, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeRawStmt, "insert data into NodeRaw"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeRawStmt, "insert data into NodeRaw"); } - sqlite3_exec(sqliteDb, "COMMIT", 0, 0, 0); + transaction.commit(); auto nodeDecryptionKeyResult = getNodeDecryptionKey(hash); if(nodeDecryptionKeyResult.first) @@ -471,7 +467,7 @@ namespace odhtdb void DatabaseStorage::appendStorage(const Hash &nodeHash, const Hash &dataHash, DatabaseOperation operation, u64 newUserActionCounter, const Signature::PublicKey &creatorPublicKey, u64 timestamp, const void *data, usize size, const DataView &additionalDataView) { - sqlite3_exec(sqliteDb, "BEGIN", 0, 0, 0); + SqlTransaction transaction(sqliteDb); { SqlQuery selectUserIdAndActionCounter(sqliteDb, "SELECT id, latestActionCounter FROM NodeUser WHERE node = ? AND publicKey = ?", @@ -490,7 +486,6 @@ namespace odhtdb if(newUserActionCounter == userActionCounter) { - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException("Got unique package but action counter was equal to users existing one, discarding packet"); } else if(newUserActionCounter == userActionCounter + 1) @@ -586,7 +581,7 @@ namespace odhtdb rc = sqlite3_bind_int(insertNodeAddDataStmt, 7, newUserActionCounter); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeAddDataStmt, "insert data into NodeAddData"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeAddDataStmt, "insert data into NodeAddData"); } i64 nodeRowId = getNodeRowId(nodeHash); @@ -607,40 +602,31 @@ namespace odhtdb rc = sqlite3_bind_blob(insertNodeAddDataAdditionalStmt, 2, additionalDataView.data, additionalDataView.size, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeAddDataAdditionalStmt, "insert data into NodeAddDataAdditional"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeAddDataAdditionalStmt, "insert data into NodeAddDataAdditional"); } else if(operation == DatabaseOperation::ADD_USER) { - try - { - sibs::SafeDeserializer deserializer((const u8*)additionalDataView.data, additionalDataView.size); - deserializer.extract((u8*)userToAddPublicKey.getData(), PUBLIC_KEY_NUM_BYTES); - deserializer.extract(groupToAddUserTo, GROUP_ID_LENGTH); - - sqlite3_reset(insertNodeAddUserDataStmt); - sqlite3_clear_bindings(insertNodeAddUserDataStmt); - - int rc; - rc = sqlite3_bind_int64(insertNodeAddUserDataStmt, 1, nodeAddRowId); - bindCheckError(rc); - - rc = sqlite3_bind_blob(insertNodeAddUserDataStmt, 2, userToAddPublicKey.getData(), PUBLIC_KEY_NUM_BYTES, SQLITE_STATIC); - bindCheckError(rc); - - rc = sqlite3_bind_blob(insertNodeAddUserDataStmt, 3, groupToAddUserTo, GROUP_ID_LENGTH, SQLITE_STATIC); - bindCheckError(rc); - - sqlite_step_rollback_on_failure(sqliteDb, insertNodeAddUserDataStmt, "insert data into NodeAddUserData"); - } - catch(sibs::DeserializeException &e) - { - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); - throw e; - } + sibs::SafeDeserializer deserializer((const u8*)additionalDataView.data, additionalDataView.size); + deserializer.extract((u8*)userToAddPublicKey.getData(), PUBLIC_KEY_NUM_BYTES); + deserializer.extract(groupToAddUserTo, GROUP_ID_LENGTH); + + sqlite3_reset(insertNodeAddUserDataStmt); + sqlite3_clear_bindings(insertNodeAddUserDataStmt); + + int rc; + rc = sqlite3_bind_int64(insertNodeAddUserDataStmt, 1, nodeAddRowId); + bindCheckError(rc); + + rc = sqlite3_bind_blob(insertNodeAddUserDataStmt, 2, userToAddPublicKey.getData(), PUBLIC_KEY_NUM_BYTES, SQLITE_STATIC); + bindCheckError(rc); + + rc = sqlite3_bind_blob(insertNodeAddUserDataStmt, 3, groupToAddUserTo, GROUP_ID_LENGTH, SQLITE_STATIC); + bindCheckError(rc); + + sqlite_step_throw_on_failure(sqliteDb, insertNodeAddUserDataStmt, "insert data into NodeAddUserData"); } else { - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw ("Unexpected operation type"); } @@ -658,7 +644,7 @@ namespace odhtdb rc = sqlite3_bind_blob(insertNodeAddDataRawStmt, 3, data, size, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeAddDataRawStmt, "insert data into NodeAddDataRaw"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeAddDataRawStmt, "insert data into NodeAddDataRaw"); } auto nodeDecryptionKeyResult = getNodeDecryptionKey(nodeHash); @@ -670,7 +656,7 @@ namespace odhtdb decryptNodeAddUser(nodeAddRowId, nodeHash, dataHash, timestamp, &creatorPublicKey, &userToAddPublicKey, DataView(groupToAddUserTo, GROUP_ID_LENGTH), nodeDecryptionKeyResult.second); } - sqlite3_exec(sqliteDb, "COMMIT", 0, 0, 0); + transaction.commit(); } void DatabaseStorage::addGroup(const Hash &nodeHash, const DataView &groupId, const Permission &permissions) @@ -691,7 +677,7 @@ namespace odhtdb rc = sqlite3_bind_int64(insertGroupStmt, 4, permissions.getPermissionFlags()); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertGroupStmt, "insert data into NodeGroup"); + sqlite_step_throw_on_failure(sqliteDb, insertGroupStmt, "insert data into NodeGroup"); Log::debug("Created group %s in node %s", bin2hex((const char*)groupId.data, GROUP_ID_LENGTH).c_str(), nodeHash.toString().c_str()); } @@ -710,7 +696,7 @@ namespace odhtdb rc = sqlite3_bind_blob(insertNodeUserGroupAssocStmt, 3, groupId.data, groupId.size, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertNodeUserGroupAssocStmt, "insert data into NodeUserGroupAssoc"); + sqlite_step_throw_on_failure(sqliteDb, insertNodeUserGroupAssocStmt, "insert data into NodeUserGroupAssoc"); } void DatabaseStorage::addUser(const Hash &nodeHash, const Signature::PublicKey &userPublicKey, const DataView &groupId) @@ -725,7 +711,7 @@ namespace odhtdb rc = sqlite3_bind_blob(insertUserStmt, 2, userPublicKey.getData(), userPublicKey.getSize(), SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, insertUserStmt, "insert data into NodeUser"); + sqlite_step_throw_on_failure(sqliteDb, insertUserStmt, "insert data into NodeUser"); addUserToGroup(nodeHash, userPublicKey, groupId); @@ -926,7 +912,7 @@ namespace odhtdb rc = sqlite3_bind_blob(setNodeDecryptionKeyStmt, 2, decryptionKeyView.data, decryptionKeyView.size, SQLITE_STATIC); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, setNodeDecryptionKeyStmt, "insert or replace node decryption key"); + sqlite_step_throw_on_failure(sqliteDb, setNodeDecryptionKeyStmt, "insert or replace node decryption key"); // When changing existing encryption key, do not decrypt the existing data as it has already been decrypted, // the new key should only be used for new data @@ -976,7 +962,7 @@ namespace odhtdb rc = sqlite3_bind_blob(selectNodeAddDataByNodeStmt, 1, nodeHash.getData(), nodeHash.getSize(), SQLITE_STATIC); bindCheckError(rc); - sqlite3_exec(sqliteDb, "BEGIN", 0, 0, 0); + SqlTransaction transaction(sqliteDb); bool success = true; while(true) { @@ -1010,7 +996,6 @@ namespace odhtdb { string errMsg = "select NodeAddDataAdditional failed with error: "; errMsg += sqlite3_errmsg(sqliteDb); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } @@ -1038,7 +1023,6 @@ namespace odhtdb { string errMsg = "select NodeAddUserData failed with error: "; errMsg += sqlite3_errmsg(sqliteDb); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } @@ -1062,11 +1046,10 @@ namespace odhtdb { string errMsg = "select NodeAddData by node failed with error: "; errMsg += sqlite3_errmsg(sqliteDb); - sqlite3_exec(sqliteDb, "ROLLBACK", 0, 0, 0); throw DatabaseStorageException(errMsg); } } - sqlite3_exec(sqliteDb, "COMMIT", 0, 0, 0); + transaction.commit(); return success; } @@ -1083,7 +1066,7 @@ namespace odhtdb rc = sqlite3_bind_int64(setNodeAddDataDecryptedStmt, 2, rowId); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, setNodeAddDataDecryptedStmt, "set NodeAddData decrypted"); + sqlite_step_throw_on_failure(sqliteDb, setNodeAddDataDecryptedStmt, "set NodeAddData decrypted"); } void DatabaseStorage::setNodeAddDataDecryptedData(i64 rowId, const DataView &decryptedData) @@ -1099,7 +1082,7 @@ namespace odhtdb rc = sqlite3_bind_int64(setNodeAddDataAdditionalDataStmt, 2, rowId); bindCheckError(rc); - sqlite_step_rollback_on_failure(sqliteDb, setNodeAddDataAdditionalDataStmt, "set NodeAddData decrypted"); + sqlite_step_throw_on_failure(sqliteDb, setNodeAddDataAdditionalDataStmt, "set NodeAddData decrypted"); } bool DatabaseStorage::decryptNodeAddData(i64 rowId, const Hash &nodeHash, const Hash &dataHash, u64 timestamp, const Signature::PublicKey *creatorPublicKey, const DataView &encryptedData, const shared_ptr<OwnedMemory> decryptionKey) diff --git a/src/sql/Sql.cpp b/src/sql/Sql.cpp index 754a30d..4e65ddb 100644 --- a/src/sql/Sql.cpp +++ b/src/sql/Sql.cpp @@ -1,5 +1,7 @@ #include "../../include/odhtdb/sql/Sql.hpp" #include <sqlite3.h> +#include <exception> +#include <cassert> namespace odhtdb { @@ -18,4 +20,22 @@ namespace odhtdb } return SQLITE_OK; } + + SqlTransaction::SqlTransaction(sqlite3 *_db) : + db(_db) + { + assert(db); + sqlite3_exec(db, "BEGIN", 0, 0, 0); + } + + SqlTransaction::~SqlTransaction() + { + if(std::uncaught_exception()) + sqlite3_exec(db, "ROLLBACK", 0, 0, 0); + } + + void SqlTransaction::commit() + { + sqlite3_exec(db, "COMMIT", 0, 0, 0); + } } diff --git a/src/sql/SqlExec.cpp b/src/sql/SqlExec.cpp index 732b2f1..bdb0fbd 100644 --- a/src/sql/SqlExec.cpp +++ b/src/sql/SqlExec.cpp @@ -12,7 +12,6 @@ namespace odhtdb { std::string errMsg = "Failed to prepare sqlite statement, error: "; errMsg += sqlite3_errmsg(db); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlExecException(errMsg); } } @@ -37,7 +36,6 @@ namespace odhtdb errMsg += " parameters, got "; errMsg += std::to_string(args.size()); errMsg += " arguments"; - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlExecException(errMsg); } @@ -49,7 +47,6 @@ namespace odhtdb { std::string errMsg = "Failed to bind param, error code: "; errMsg += std::to_string(rc); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlExecException(errMsg); } ++paramIndex; @@ -60,7 +57,6 @@ namespace odhtdb { std::string errMsg = "Failed to perform sql exec, error: "; errMsg += sqlite3_errmsg(db); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlExecException(errMsg); } } diff --git a/src/sql/SqlQuery.cpp b/src/sql/SqlQuery.cpp index b99f92d..6201332 100644 --- a/src/sql/SqlQuery.cpp +++ b/src/sql/SqlQuery.cpp @@ -13,7 +13,6 @@ namespace odhtdb { std::string errMsg = "Failed to prepare sqlite statement, error: "; errMsg += sqlite3_errmsg(db); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlQueryException(errMsg); } @@ -27,7 +26,6 @@ namespace odhtdb errMsg += " arguments"; sqlite3_finalize(stmt); stmt = nullptr; - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlQueryException(errMsg); } @@ -41,7 +39,6 @@ namespace odhtdb errMsg += std::to_string(rc); sqlite3_finalize(stmt); stmt = nullptr; - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlQueryException(errMsg); } ++paramIndex; @@ -62,7 +59,6 @@ namespace odhtdb { std::string errMsg = "Failed to perform sql select, error: "; errMsg += sqlite3_errmsg(db); - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlQueryException(errMsg); } @@ -88,7 +84,6 @@ namespace odhtdb errMsg += " has to be between 0 and "; errMsg += std::to_string(numColumns - 1); } - sqlite3_exec(db, "ROLLBACK", 0, 0, 0); throw SqlQueryException(errMsg); } } |