#include "../../include/odhtdb/sql/SqlQuery.hpp"
#include <sqlite3.h>
#include <cstddef>
#include <limits>
#include <string>

namespace odhtdb
{
    namespace
    {
        // Issues a best-effort ROLLBACK on db (its result is ignored) and then throws
        // SqlQueryException(errMsg). Callers build errMsg (including sqlite3_errmsg(db))
        // *before* calling this, so the message describes the original failure rather
        // than the outcome of the ROLLBACK.
        [[noreturn]] void rollbackAndThrow(sqlite3 *db, const std::string &errMsg)
        {
            sqlite3_exec(db, "ROLLBACK", 0, 0, 0);
            throw SqlQueryException(errMsg);
        }
    }

    // Binds this argument to the 1-based parameter `paramIndex` of `stmt`.
    // Returns the SQLite result code of the bind (SQLITE_OK on success). A type not
    // handled by the switch binds nothing and reports SQLITE_OK.
    // DATA_VIEW blobs are bound with SQLITE_STATIC, so SQLite does not copy them: the
    // memory dataView points to must stay valid until the statement is finalized or
    // the parameter is re-bound.
    int SqlArg::bind(sqlite3_stmt *stmt, int paramIndex) const
    {
        switch(type)
        {
            case Type::DATA_VIEW:
                // sqlite3_bind_blob takes an int length and a negative length is undefined
                // behaviour, so reject blobs that do not fit instead of truncating the size.
                if(dataView.size > static_cast<usize>(std::numeric_limits<int>::max()))
                    return SQLITE_TOOBIG;
                return sqlite3_bind_blob(stmt, paramIndex, dataView.data, static_cast<int>(dataView.size), SQLITE_STATIC);
            case Type::INT:
                return sqlite3_bind_int(stmt, paramIndex, integer);
            case Type::INT64:
                return sqlite3_bind_int64(stmt, paramIndex, integer64);
        }
        return SQLITE_OK;
    }

    // Prepares `sql` on `_db` and binds `args` to its parameters in order
    // (first argument -> parameter 1). The number of arguments must equal the number
    // of parameters in `sql`.
    // On any failure the statement (if prepared) is finalized, a ROLLBACK is issued on
    // the connection, and SqlQueryException is thrown. Cleanup happens here because a
    // throwing constructor never runs the destructor.
    SqlQuery::SqlQuery(sqlite3 *_db, const char *sql, std::initializer_list<SqlArg> args) :
        db(_db),
        stmt(nullptr),
        numColumns(0)
    {
        int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr);
        if(rc != SQLITE_OK)
        {
            std::string errMsg = "Failed to prepare sqlite statement, error: ";
            errMsg += sqlite3_errmsg(db);
            rollbackAndThrow(db, errMsg);
        }

        const int numParams = sqlite3_bind_parameter_count(stmt);
        if(args.size() != static_cast<std::size_t>(numParams))
        {
            std::string errMsg = "Failed to prepare sqlite statement, error: Sql has ";
            errMsg += std::to_string(numParams);
            errMsg += " parameters, got ";
            errMsg += std::to_string(args.size());
            errMsg += " arguments";
            sqlite3_finalize(stmt);
            stmt = nullptr;
            rollbackAndThrow(db, errMsg);
        }

        int paramIndex = 1; // SQLite parameter indices are 1-based
        for(const SqlArg &arg : args)
        {
            rc = arg.bind(stmt, paramIndex);
            if(rc != SQLITE_OK)
            {
                std::string errMsg = "Failed to bind param, error code: ";
                errMsg += std::to_string(rc);
                sqlite3_finalize(stmt);
                stmt = nullptr;
                rollbackAndThrow(db, errMsg);
            }
            ++paramIndex;
        }
    }

    // Finalizes the prepared statement (finalizing a null statement is a no-op).
    SqlQuery::~SqlQuery()
    {
        sqlite3_finalize(stmt);
    }

    // Advances to the next result row.
    // Returns true when a row is available (its columns can then be read with the
    // get* functions), and false when the statement has finished.
    // Throws SqlQueryException (after a ROLLBACK) if sqlite3_step fails.
    bool SqlQuery::next()
    {
        const int rc = sqlite3_step(stmt);
        if(rc == SQLITE_ROW)
        {
            numColumns = sqlite3_data_count(stmt);
            return true;
        }

        // There is no current row any more. Forget the previous row's column count so
        // checkColumnIndex rejects reads instead of letting sqlite3_column_* run
        // without a valid row (undefined in SQLite).
        numColumns = 0;
        if(rc == SQLITE_DONE)
            return false;

        std::string errMsg = "Failed to perform sql select, error: ";
        errMsg += sqlite3_errmsg(db);
        rollbackAndThrow(db, errMsg);
    }

    // Throws SqlQueryException (after a ROLLBACK) unless 0 <= index < numColumns,
    // i.e. unless `index` names a column of the current row.
    void SqlQuery::checkColumnIndex(int index)
    {
        if(index >= 0 && index < numColumns)
            return;

        std::string errMsg;
        if(numColumns == 0)
        {
            errMsg += "Attempt to get column ";
            errMsg += std::to_string(index);
            errMsg += " but result does not have any columns";
        }
        else
        {
            errMsg += "Column index ";
            errMsg += std::to_string(index);
            errMsg += " has to be between 0 and ";
            errMsg += std::to_string(numColumns - 1);
        }
        rollbackAndThrow(db, errMsg);
    }

    // Returns column `index` (0-based) of the current row as an int.
    int SqlQuery::getInt(int index)
    {
        checkColumnIndex(index);
        return sqlite3_column_int(stmt, index);
    }

    // Returns column `index` (0-based) of the current row as a 64-bit integer.
    i64 SqlQuery::getInt64(int index)
    {
        checkColumnIndex(index);
        return sqlite3_column_int64(stmt, index);
    }

    // Returns column `index` (0-based) of the current row as a non-owning view of the
    // blob bytes. The memory belongs to SQLite and is only valid until the next
    // next()/step, reset or finalize of this statement. SQLite returns a null pointer
    // for a zero-length blob.
    const DataView SqlQuery::getBlob(int index)
    {
        checkColumnIndex(index);
        // Call sqlite3_column_blob before sqlite3_column_bytes, as recommended by SQLite,
        // so the byte count matches the returned representation.
        const void *data = sqlite3_column_blob(stmt, index);
        const int size = sqlite3_column_bytes(stmt, index);
        return { const_cast<void*>(data), static_cast<usize>(size) };
    }
}