From 656333b11daf90da2415c2386f86668e0afa5daa Mon Sep 17 00:00:00 2001
From: kjm <2646402264@qq.com>
Date: Fri, 28 Jun 2024 12:08:04 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20Context=20=E9=87=8C=E7=9A=84=20Transacti?=
 =?UTF-8?q?on=20=E6=94=B9=E7=94=A8=20shared=5Fptr=EF=BC=8C=E4=BF=AE?=
 =?UTF-8?q?=E5=A4=8D=E5=86=85=E5=AD=98=E6=B3=84=E6=BC=8F=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 CMakeLists.txt                          |  7 +++++++
 src/common/context.h                    |  2 +-
 src/execution/executor_delete.h         |  4 ++--
 src/execution/executor_index_scan.h     |  2 +-
 src/execution/executor_insert.h         |  6 +++---
 src/execution/executor_seq_scan.h       |  2 +-
 src/execution/executor_update.h         |  8 +++----
 src/rmdb.cpp                            |  5 -----
 src/system/sm_manager.cpp               | 28 ++++++++++++-------------
 src/transaction/transaction_manager.cpp | 12 +++++------
 src/transaction/transaction_manager.h   | 10 ++++-----
 test/multi_client.py                    |  6 +++---
 12 files changed, 47 insertions(+), 45 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef3e9a9..e0f1983 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,6 +12,13 @@ set(CMAKE_CXX_FLAGS "-Wall -O2")
 # set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g")
 # set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -g")
 
+# 检测内存泄漏
+# set(CMAKE_CXX_STANDARD 17)
+# set(CMAKE_CXX_FLAGS "-Wall -O0 -g -ggdb3 -fsanitize=undefined,address,leak -fno-omit-frame-pointer")
+# # set(CMAKE_CXX_FLAGS "-Wall -O2")
+
+# set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fsanitize=undefined,address,leak -fno-omit-frame-pointer")
+# set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -g -fsanitize=undefined,address,leak -fno-omit-frame-pointer")
 
 enable_testing()
 add_subdirectory(src)
diff --git a/src/common/context.h b/src/common/context.h
index faea9f1..bca9fc3 100644
--- a/src/common/context.h
+++ b/src/common/context.h
@@ -30,7 +30,7 @@ public:
     // TransactionManager *txn_mgr_;
     LockManager* lock_mgr_;
     LogManager* log_mgr_;
-    Transaction* txn_;
+    std::shared_ptr<Transaction> txn_;
     char* data_send_;
     int* offset_;
     bool ellipsis_;
diff --git a/src/execution/executor_delete.h b/src/execution/executor_delete.h
index 0204edf..154cf90 100644
--- a/src/execution/executor_delete.h
+++ b/src/execution/executor_delete.h
@@ -36,7 +36,7 @@ public:
         context_ = context;
 
         // 获取表锁:X锁
-        context_->lock_mgr_->lock_exclusive_on_table(context->txn_, fh_->GetFd());
+        context_->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fh_->GetFd());
     }
 
     std::unique_ptr<RmRecord> Next() override {
@@ -49,7 +49,7 @@ public:
                 auto& index = tab_.indexes[i];
                 auto ih = sm_manager_->ihs_.at(index.index_name).get();
                 auto delete_key = index.get_key(rec->data);
-                ih->delete_entry(delete_key.get(), context_->txn_);
+                ih->delete_entry(delete_key.get(), context_->txn_.get());
             }
             fh_->delete_record(rid, context_);
             // add write record
diff --git a/src/execution/executor_index_scan.h b/src/execution/executor_index_scan.h
index 16f38b6..225830e 100644
--- a/src/execution/executor_index_scan.h
+++ b/src/execution/executor_index_scan.h
@@ -68,7 +68,7 @@ public:
         isend_ = false;
 
         // 获取表锁:S锁
-        context_->lock_mgr_->lock_shared_on_table(context->txn_, fh_->GetFd());
+        context_->lock_mgr_->lock_shared_on_table(context->txn_.get(), fh_->GetFd());
     }
 
     ColMeta get_col_offset(const TabCol& target) {
diff --git a/src/execution/executor_insert.h b/src/execution/executor_insert.h
index 235590e..1e67281 100644
--- a/src/execution/executor_insert.h
+++ b/src/execution/executor_insert.h
@@ -37,7 +37,7 @@ public:
         context_ = context;
 
         // 获取表锁:X锁
-        context_->lock_mgr_->lock_exclusive_on_table(context->txn_, fh_->GetFd());
+        context_->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fh_->GetFd());
     };
 
     std::unique_ptr<RmRecord> Next() override {
@@ -66,7 +66,7 @@ public:
             auto& index = tab_.indexes[i];
             auto ih = sm_manager_->ihs_.at(index.index_name).get();
             auto insert_key = index.get_key(rec.data);
-            if (ih->get_value(insert_key.get(), &result, context_->txn_)) {
+            if (ih->get_value(insert_key.get(), &result, context_->txn_.get())) {
                 throw DuplicateKeyError(index.index_name, tab_.name);
             }
         }
@@ -78,7 +78,7 @@ public:
             auto& index = tab_.indexes[i];
             auto ih = sm_manager_->ihs_.at(index.index_name).get();
             auto insert_key = index.get_key(rec.data);
-            ih->insert_entry(insert_key.get(), rid_, context_->txn_);
+            ih->insert_entry(insert_key.get(), rid_, context_->txn_.get());
         }
 
         // add write record
diff --git a/src/execution/executor_seq_scan.h b/src/execution/executor_seq_scan.h
index f47a838..f98fd2e 100644
--- a/src/execution/executor_seq_scan.h
+++ b/src/execution/executor_seq_scan.h
@@ -47,7 +47,7 @@ public:
         fed_conds_ = conds_;
 
         // 获取表锁:S锁
-        context_->lock_mgr_->lock_shared_on_table(context->txn_, fh_->GetFd());
+        context_->lock_mgr_->lock_shared_on_table(context->txn_.get(), fh_->GetFd());
     }
 
     ColMeta get_col_offset(const TabCol& target) {
diff --git a/src/execution/executor_update.h b/src/execution/executor_update.h
index f54744a..a03400d 100644
--- a/src/execution/executor_update.h
+++ b/src/execution/executor_update.h
@@ -38,7 +38,7 @@ public:
         context_ = context;
 
         // 获取表锁:X锁
-        context_->lock_mgr_->lock_exclusive_on_table(context->txn_, fh_->GetFd());
+        context_->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fh_->GetFd());
     }
     std::unique_ptr<RmRecord> Next() override {
         // 处理set_clause
@@ -74,7 +74,7 @@ public:
                     continue;
                 }
                 std::vector<Rid> result;
-                if (ih->get_value(new_key.get(), &result, context_->txn_)) {
+                if (ih->get_value(new_key.get(), &result, context_->txn_.get())) {
                     throw DuplicateKeyError(index.index_name, tab_.name);
                 }
             }
@@ -89,8 +89,8 @@ public:
                 if (memcmp(old_key.get(), new_key.get(), index.col_tot_len) == 0) {
                     continue;
                 }
-                ih->delete_entry(old_key.get(), context_->txn_);
-                ih->insert_entry(new_key.get(), rid, context_->txn_);
+                ih->delete_entry(old_key.get(), context_->txn_.get());
+                ih->insert_entry(new_key.get(), rid, context_->txn_.get());
             }
 
             // add write record
diff --git a/src/rmdb.cpp b/src/rmdb.cpp
index 4ed72f5..a56efed 100644
--- a/src/rmdb.cpp
+++ b/src/rmdb.cpp
@@ -67,10 +67,6 @@ void SetTransaction(txn_id_t* txn_id, Context* context) {
     context->txn_ = txn_manager->get_transaction(*txn_id);
     if (context->txn_ == nullptr || context->txn_->get_state() == TransactionState::COMMITTED ||
         context->txn_->get_state() == TransactionState::ABORTED) {
-        if (context->txn_) {
-            delete context->txn_;
-            context->txn_ = nullptr;
-        }
         context->txn_ = txn_manager->begin(context);
         *txn_id = context->txn_->get_transaction_id();
         context->txn_->set_txn_mode(false);
@@ -355,7 +351,6 @@ int main(int argc, char** argv) {
         log_manager->start_flush_thread();
 
         delete context;
-        delete txn;
 
         // 开启服务端,开始接受客户端连接
         start_server();
diff --git a/src/system/sm_manager.cpp b/src/system/sm_manager.cpp
index d8a5ea2..a9044ce 100644
--- a/src/system/sm_manager.cpp
+++ b/src/system/sm_manager.cpp
@@ -257,7 +257,7 @@ void SmManager::create_table(const std::string& tab_name, const std::vector<ColD
     flush_meta();
 
     // 加上表锁
-    context->lock_mgr_->lock_exclusive_on_table(context->txn_, fhs_.at(tab_name)->GetFd());
+    context->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fhs_.at(tab_name)->GetFd());
 }
 
 /**
@@ -270,7 +270,7 @@ void SmManager::drop_table(const std::string& tab_name, Context* context) {
         throw TableNotFoundError(tab_name);
     }
     // 加上表锁
-    context->lock_mgr_->lock_exclusive_on_table(context->txn_, fhs_.at(tab_name)->GetFd());
+    context->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fhs_.at(tab_name)->GetFd());
 
     // drop index
     TabMeta& tab = db_.get_table(tab_name);
@@ -308,7 +308,7 @@ void SmManager::create_index(const std::string& tab_name, const std::vector<std:
     }
 
     // 加上表锁
-    context->lock_mgr_->lock_shared_on_table(context->txn_, fhs_.at(tab_name)->GetFd());
+    context->lock_mgr_->lock_shared_on_table(context->txn_.get(), fhs_.at(tab_name)->GetFd());
 
     // 创建 index meta
     std::vector<ColMeta> col_metas;
@@ -333,14 +333,14 @@ void SmManager::create_index(const std::string& tab_name, const std::vector<std:
         auto key = index.get_key(rec->data);
         // 检查是否有重复的key
         std::vector<Rid> result;
-        if (ih->get_value(key.get(), &result, context->txn_)) {
+        if (ih->get_value(key.get(), &result, context->txn_.get())) {
             // 有重复的,删掉创建的索引
             ix_manager_->close_index(ih.get());
             buffer_pool_manager_->delete_pages(ih->get_fd());
             ix_manager_->destroy_index(tab_name, col_names);
             throw DuplicateKeyError(index.index_name, tab_name);
         }
-        ih->insert_entry(key.get(), rm_scan.rid(), context->txn_);
+        ih->insert_entry(key.get(), rm_scan.rid(), context->txn_.get());
     }
     ihs_[index.index_name] = std::move(ih);
     tab.indexes.push_back(std::move(index));
@@ -359,7 +359,7 @@ void SmManager::drop_index(const std::string& tab_name, const std::vector<std::s
     }
     TabMeta& tab = db_.get_table(tab_name);
     // 加上表锁
-    context->lock_mgr_->lock_exclusive_on_table(context->txn_, fhs_.at(tab_name)->GetFd());
+    context->lock_mgr_->lock_exclusive_on_table(context->txn_.get(), fhs_.at(tab_name)->GetFd());
     // 删除索引
     tab.drop_index(col_names);
     auto index_name = ix_manager_->get_index_name(tab_name, col_names);
@@ -402,7 +402,7 @@ void SmManager::rollback_insert(const std::string& tab_name, const Rid& rid, Con
     for (auto& index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(rec->data);
-        ih->delete_entry(key.get(), context->txn_);
+        ih->delete_entry(key.get(), context->txn_.get());
     }
     // delete record
     fhs_.at(tab_name)->delete_record(rid, context);
@@ -422,7 +422,7 @@ void SmManager::rollback_delete(const std::string& tab_name, const Rid& rid, con
     for (auto& index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(record.data);
-        ih->insert_entry(key.get(), rid, context->txn_);
+        ih->insert_entry(key.get(), rid, context->txn_.get());
     }
     // insert record
     fhs_.at(tab_name)->insert_record(rid, record.data, context);
@@ -443,13 +443,13 @@ void SmManager::rollback_update(const std::string& tab_name, const Rid& rid, con
     for (auto& index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(new_rec_ptr->data);
-        ih->delete_entry(key.get(), context->txn_);
+        ih->delete_entry(key.get(), context->txn_.get());
     }
     // insert old key in index
     for (auto& index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(record.data);
-        ih->insert_entry(key.get(), rid, context->txn_);
+        ih->insert_entry(key.get(), rid, context->txn_.get());
     }
     // update record
     fhs_.at(tab_name)->update_record(rid, record.data, context);
@@ -463,7 +463,7 @@ void SmManager::redo_insert(const std::string &tab_name, const Rid &rid, const R
     for (auto &index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(record.data);
-        ih->insert_entry(key.get(), rid, context->txn_);
+        ih->insert_entry(key.get(), rid, context->txn_.get());
     }
 }
 
@@ -475,7 +475,7 @@ void SmManager::redo_delete(const std::string &tab_name, const Rid &rid, const R
     for (auto &index : tab.indexes) {
         auto ih = ihs_.at(index.index_name).get();
         auto key = index.get_key(record.data);
-        ih->delete_entry(key.get(), context->txn_);
+        ih->delete_entry(key.get(), context->txn_.get());
     }
 }
 
@@ -488,7 +488,7 @@ void SmManager::redo_update(const std::string &tab_name, const Rid &rid, const R
         auto ih = ihs_.at(index.index_name).get();
         auto delete_key = index.get_key(old_record.data);
         auto insert_key = index.get_key(new_record.data);
-        ih->delete_entry(delete_key.get(), context->txn_);
-        ih->insert_entry(insert_key.get(), rid, context->txn_);
+        ih->delete_entry(delete_key.get(), context->txn_.get());
+        ih->insert_entry(insert_key.get(), rid, context->txn_.get());
     }
 }
\ No newline at end of file
diff --git a/src/transaction/transaction_manager.cpp b/src/transaction/transaction_manager.cpp
index 11ca414..c55221e 100644
--- a/src/transaction/transaction_manager.cpp
+++ b/src/transaction/transaction_manager.cpp
@@ -12,21 +12,21 @@ See the Mulan PSL v2 for more details. */
 #include "record/rm_file_handle.h"
 #include "system/sm_manager.h"
 
-std::unordered_map<txn_id_t, Transaction*> TransactionManager::txn_map = {};
+std::unordered_map<txn_id_t, std::shared_ptr<Transaction>> TransactionManager::txn_map = {};
 
 /**
  * @description: 事务的开始方法
- * @return {Transaction*} 开始事务的指针
+ * @return {std::shared_ptr<Transaction>} 开始事务的指针
  * @param {Context *} context 事务上下文
  */
-Transaction* TransactionManager::begin(Context* context) {
+std::shared_ptr<Transaction> TransactionManager::begin(Context* context) {
     // Todo:
     // 1. 判断传入事务参数是否为空指针
     // 2. 如果为空指针,创建新事务
     // 3. 把开始事务加入到全局事务表中
     // 4. 返回当前事务指针
     if (context->txn_ == nullptr) {
-        context->txn_ = new Transaction(next_txn_id_++);
+        context->txn_ = std::make_shared<Transaction>(next_txn_id_++);
         // txn->set_state(TransactionState::DEFAULT);
     }
     // begin log
@@ -67,7 +67,7 @@ void TransactionManager::commit(Context* context) {
     }
 
     // 释放所有锁
-    lock_manager_->unlock_all(context->txn_);
+    lock_manager_->unlock_all(context->txn_.get());
 
     // 日志刷盘
     context->log_mgr_->can_flush();
@@ -123,7 +123,7 @@ void TransactionManager::abort(Context* context) {
         context->txn_->set_prev_lsn(context->log_mgr_->add_log_to_buffer(&log_record));
     }
     // 释放所有锁
-    lock_manager_->unlock_all(context->txn_);
+    lock_manager_->unlock_all(context->txn_.get());
     
     // 日志刷盘
     context->log_mgr_->can_flush();
diff --git a/src/transaction/transaction_manager.h b/src/transaction/transaction_manager.h
index aa30100..e781105 100644
--- a/src/transaction/transaction_manager.h
+++ b/src/transaction/transaction_manager.h
@@ -30,7 +30,7 @@ public:
 
     ~TransactionManager() = default;
 
-    Transaction* begin(Context* context);
+    std::shared_ptr<Transaction> begin(Context* context);
 
     void commit(Context* context);
 
@@ -48,10 +48,10 @@ public:
 
     /**
      * @description: 获取事务ID为txn_id的事务对象
-     * @return {Transaction*} 事务对象的指针
+     * @return {std::shared_ptr<Transaction>} 事务对象的指针
      * @param {txn_id_t} txn_id 事务ID
      */
-    Transaction* get_transaction(txn_id_t txn_id) {
+    std::shared_ptr<Transaction> get_transaction(txn_id_t txn_id) {
         if (txn_id == INVALID_TXN_ID)
             return nullptr;
 
@@ -59,7 +59,7 @@ public:
         // assert(TransactionManager::txn_map.find(txn_id) != TransactionManager::txn_map.end());
         if (TransactionManager::txn_map.find(txn_id) == TransactionManager::txn_map.end())
             return nullptr;
-        auto* res = TransactionManager::txn_map[txn_id];
+        auto res = TransactionManager::txn_map[txn_id];
         lock.unlock();
         assert(res != nullptr);
         assert(res->get_thread_id() == std::this_thread::get_id());
@@ -67,7 +67,7 @@ public:
         return res;
     }
 
-    static std::unordered_map<txn_id_t, Transaction*> txn_map; // 全局事务表,存放事务ID与事务对象的映射关系
+    static std::unordered_map<txn_id_t, std::shared_ptr<Transaction>> txn_map; // 全局事务表,存放事务ID与事务对象的映射关系
     std::shared_mutex txn_lock_;                               // 用于在创建日志checkpoint阻塞所有事务
 
 private:
diff --git a/test/multi_client.py b/test/multi_client.py
index 55e714b..6422a1a 100644
--- a/test/multi_client.py
+++ b/test/multi_client.py
@@ -33,8 +33,8 @@ def test(thread_id, host, port):
     try:
         # 连接到服务器
         client_socket.connect((host, port))
-
-        for i in range(100):
+        n = 1000
+        for i in range(n):
             message = 'insert into t values (%d);' % (i*5 + thread_id)
             send(client_socket, message)
         
@@ -48,7 +48,7 @@ def test(thread_id, host, port):
 
         time.sleep(5)
 
-        for i in range(100):
+        for i in range(n):
             message = 'delete from t where id=%d;' % (i*5 + thread_id)
             send(client_socket, message)
         
-- 
GitLab