From d9480b4a741078505bf6ab022b4453ad0f4003ff Mon Sep 17 00:00:00 2001
From: lotus_grow <531096131@qq.com>
Date: Thu, 6 Jun 2024 22:03:39 +0800
Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9Agroup=20by=20=E5=AE=8C=E6=88=90?=
 =?UTF-8?q?=EF=BC=8C=E4=BD=86free=E6=97=A0=E6=95=88=E5=86=85=E5=AD=98?=
 =?UTF-8?q?=E5=9C=B0=E5=9D=80?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/analyze/analyze.cpp             |  11 +-
 src/common/common.h                 |  12 +++
 src/execution/execution_defs.h      |   6 ++
 src/execution/executor_group.h      | 154 +++++++++++++++++++++++++---
 src/execution/executor_projection.h |  29 ++++--
 test/4.1-more.txt                   |   1 +
 6 files changed, 188 insertions(+), 25 deletions(-)

diff --git a/src/analyze/analyze.cpp b/src/analyze/analyze.cpp
index 47d9b29..9f97766 100644
--- a/src/analyze/analyze.cpp
+++ b/src/analyze/analyze.cpp
@@ -126,15 +126,22 @@ std::shared_ptr<Query> Analyze::do_analyze(std::shared_ptr<ast::TreeNode> parse)
                 if (auto rhs_val = std::dynamic_pointer_cast<ast::Value>(cond->rhs)) {
                     aggr_cond.is_rhs_val = true;
                     aggr_cond.rhs_val = convert_sv_value(rhs_val);
+                    if (aggr_cond.rhs_val.type == TYPE_STRING) {
+                        throw GroupByError("The right side of a conditional expression in HAVING clause cannot be a string");
+                    }
                 } else if (auto rhs_col = std::dynamic_pointer_cast<ast::Col>(cond->rhs)) {
-                    aggr_cond.is_rhs_val = false;
-                    aggr_cond.rhs_col = {.tab_name = rhs_col->tab_name, .col_name = rhs_col->col_name};
+                    // aggr_cond.is_rhs_val = false;
+                    // aggr_cond.rhs_col = {.tab_name = rhs_col->tab_name, .col_name = rhs_col->col_name};
+                    throw GroupByError("The right side of a conditional expression in HAVING clause cannot be an identifier");
                 }
                 query->aggr_conds.push_back(aggr_cond);
             }
             // 检查having条件格式
             check_clause(query->tables, query->aggr_conds);
             for (auto& aggr_cond : query->aggr_conds) {
+                // 特判count(*)
+                if (aggr_cond.lhs_col.type == COUNT && aggr_cond.lhs_col.tab_name.empty() && aggr_cond.lhs_col.col_name.empty()) 
+                    continue;
                 TabMeta& tab = sm_manager_->db_.get_table(aggr_cond.lhs_col.tab_name);
                 ColType coltype = tab.get_col(aggr_cond.lhs_col.col_name)->type;
                 check_aggr_col(coltype, aggr_cond.lhs_col.type);
diff --git a/src/common/common.h b/src/common/common.h
index f92e817..430de78 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -39,6 +39,18 @@ struct Value {
 
     std::shared_ptr<RmRecord> raw; // raw record buffer
 
+    bool operator==(const Value& a) const { 
+        if (this->type != a.type) {
+            throw IncompatibleTypeError(coltype2str(this->type), coltype2str(a.type));
+        }
+        if (this->type == TYPE_INT) {
+            return this->int_val == a.int_val;
+        } else if (this->type == TYPE_FLOAT) {
+            return this->float_val == a.float_val;
+        } 
+        return this->str_val == a.str_val; 
+    }
+
     void set_int(int int_val_) {
         type = TYPE_INT;
         int_val = int_val_;
diff --git a/src/execution/execution_defs.h b/src/execution/execution_defs.h
index e786336..bf24b05 100644
--- a/src/execution/execution_defs.h
+++ b/src/execution/execution_defs.h
@@ -12,3 +12,9 @@ See the Mulan PSL v2 for more details. */
 
 #include "defs.h"
 #include "errors.h"
+
+struct Aggr_Col {
+    size_t sel_col_idx;   // 字段在上一层的idx
+    size_t col_idx;       // 字段在当前层的idx
+    AggrType type;
+};
\ No newline at end of file
diff --git a/src/execution/executor_group.h b/src/execution/executor_group.h
index 529352c..ab61d89 100644
--- a/src/execution/executor_group.h
+++ b/src/execution/executor_group.h
@@ -10,12 +10,12 @@ namespace std {
     template <> //function-template-specialization
         class hash<Value>{
         public :
-            size_t operator()(const Value &value) const{
-                if (value.type == TYPE_STRING) 
-                    return hash<string>()(value.str_val);
-                else if (value.type == TYPE_INT)
+            size_t operator()(const Value &value) const {
+                if (value.type == TYPE_INT)
                     return hash<int>()(value.int_val);
-                return hash<float>()(value.float_val);
+                else if (value.type == TYPE_FLOAT)
+                    return hash<float>()(value.float_val);
+                return hash<string>()(value.str_val);
             }
     };
 };
@@ -26,38 +26,160 @@ private:
     std::unique_ptr<AbstractExecutor> prev_;       // 上一层executor
     ColMeta groupby_;                              // 分组的列元数据
     std::vector<Condition> conds_;                 // having条件
-    std::vector<ColMeta> cols_;                    // group后生成的记录的字段
+    std::vector<ColMeta> prev_cols_;               // 上一层记录
+    std::vector<ColMeta> new_cols_;                // group后生成的记录的字段 : 聚合函数列记录 + having 条件中的聚合函数列
     size_t len_;                                   // group后生成的每条记录的长度
-    std::vector<TabCol> aggr_cols_;                // 包含聚合函数的字段
+    std::vector<Aggr_Col> aggr_cols_;              // 包含聚合函数的字段
 
-    std::vector<std::vector<Value>> groups_;       // 分组结果: <分组依据 + 聚合函数值(按顺序)>
-    std::unordered_map<Value, int> groups_idx_;    // 在vector中的索引
+    std::vector<RmRecord> groups_;                 // 分组结果: <聚合函数值(按顺序)+ 分组依据>
+    std::unordered_map<Value, size_t> groups_idx_; // 在vector中的索引
     size_t scan_idx_;                              // 下一层扫描时使用
+    size_t having_idx_;                            // having中聚合函数列在cols中的起始位置
 
 public:
     GroupExecutor(std::unique_ptr<AbstractExecutor> prev, TabCol groupby, std::vector<TabCol> aggr_cols, std::vector<Condition> conds, Context* context) {
         prev_ = std::move(prev);
-        cols_ = std::move(prev_->cols());
-        groupby_ = *get_col(cols_, groupby);
         conds_ = std::move(conds);
         context_ = context;
-        aggr_cols_ = std::move(aggr_cols);
-        len_ = prev_->tupleLen();
+        
+        size_t curr_offset = 0;
+        prev_cols_ = std::move(prev_->cols());
+        memcpy(&groupby_, &(*get_col(prev_cols_, groupby)), sizeof(ColMeta));
+
+        // 添加聚合函数列
+        for (auto& aggr_col : aggr_cols) {
+            auto pos = get_col(prev_cols_, aggr_col);
+            Aggr_Col aggr_col_ = {
+                .sel_col_idx = pos - prev_cols_.begin(), 
+                .col_idx = new_cols_.size(), 
+                .type = aggr_col.type
+            };
+            aggr_cols_.push_back(aggr_col_);
+
+            ColMeta col;
+            memcpy(&col, &(*pos), sizeof(ColMeta));
+            // 处理COUNT
+            if (aggr_col.type == COUNT) {
+                col.len = sizeof(int);
+                col.type = TYPE_INT;
+            }
+            col.offset = curr_offset;
+            curr_offset += col.len;
+            new_cols_.push_back(col);
+        }
+
+        // 添加having中的聚合函数列
+        having_idx_ = new_cols_.size();
+        for (auto& cond_ : conds_) {
+            auto pos = get_col(prev_cols_, cond_.lhs_col);
+            Aggr_Col aggr_col_ = {
+                .sel_col_idx = pos - prev_cols_.begin(), 
+                .col_idx = prev_cols_.size(), 
+                .type = cond_.lhs_col.type
+            };
+            aggr_cols_.push_back(aggr_col_);
+
+            ColMeta col;
+            memcpy(&col, &(*pos), sizeof(ColMeta));
+            // 处理COUNT
+            if (cond_.lhs_col.type == COUNT) {
+                col.len = sizeof(int);
+                col.type = TYPE_INT;
+            }
+            col.offset = curr_offset;
+            curr_offset += col.len;
+            new_cols_.push_back(col);
+        }
+
+        len_ = curr_offset;
 
         scan_idx_ = 0;
     }
-    const std::vector<ColMeta>& cols() const { return cols_; }
+
+    const std::vector<ColMeta>& cols() const { return prev_cols_; }
+
+    void nextMatch() {
+        while (!is_end()) {
+            bool ok = true;
+            size_t idx = having_idx_;
+            for (auto& cond : conds_) {
+                // 读取左字段的值
+                ColType ltype = new_cols_[idx].type;
+                char* lbuff = groups_[scan_idx_].data + new_cols_[idx].offset;
+                int left_len = new_cols_[idx].len;
+                idx++;
+                // 读取右字段的值
+                ColType rtype = cond.rhs_val.type;
+                char* rbuff = cond.rhs_val.raw->data;
+                int right_len = cond.rhs_val.raw->size;
+                int op = comp_lhs_rhs(ltype, lbuff, rtype, rbuff, left_len, right_len);
+                ok = is_match(cond.op, op);
+                if (!ok)
+                    break;
+            }
+
+            if (ok)
+                break;
+            ++scan_idx_;
+        }
+    }
+
     void beginTuple() override {
         // 遍历上一层所有tuple
         for (prev_->beginTuple(); !prev_->is_end(); prev_->nextTuple()) {
             auto tuple = prev_->Next();
+            // 获取group by 列的值
+            char* groupby_pos = tuple->data + groupby_.offset;
+            Value groupby_val;
+            if (groupby_.type == TYPE_INT) {
+                groupby_val.set_int(*(int*)groupby_pos);
+            } else if (groupby_.type == TYPE_FLOAT) {
+                groupby_val.set_float(*(float*)groupby_pos);
+            } else {
+                std::string str = "";
+                for (int i = 0; i < groupby_.len; i++) {
+                    str.push_back(groupby_pos[i]);
+                }
+                groupby_val.set_str(str);
+            }
+
+            // 获取tuple分在哪一组
+            if (!groups_idx_.count(groupby_val)) {
+                groups_idx_[groupby_val] = groups_.size();
+                groups_.emplace_back(len_ + groupby_.len);
+                // 初始化
+                memcpy(groups_.back().data + len_, groupby_pos, groupby_.len);
+                for (auto& aggr_col : aggr_cols_) {
+                    groups_.back().initialize_aggr(aggr_col.type, 
+                                                   new_cols_[aggr_col.col_idx].type,
+                                                   new_cols_[aggr_col.col_idx].offset);
+                }
+            }
+            size_t idx = groups_idx_[groupby_val];
 
+            // 更新聚合函数值
+            for (auto& aggr_col : aggr_cols_) {
+                groups_[idx].update(aggr_col.type, 
+                                    new_cols_[aggr_col.col_idx].type, 
+                                    new_cols_[aggr_col.col_idx].offset, 
+                                    tuple->data + prev_cols_[aggr_col.sel_col_idx].offset);
+            }
         }
+
+        scan_idx_ = 0;
+        // 指向第一个符合条件的Tuple
+        nextMatch();
     }
 
-    void nextTuple() override { ++scan_idx_; }
+    void nextTuple() override { 
+        // 跳过上一个符合条件的Tuple
+        ++scan_idx_; 
+        nextMatch();
+    }
 
-    std::unique_ptr<RmRecord> Next() override { return nullptr; }
+    std::unique_ptr<RmRecord> Next() override { 
+        return std::make_unique<RmRecord>(groups_[scan_idx_]); 
+    }
 
     bool is_end() const { return scan_idx_ == groups_.size(); }
 
diff --git a/src/execution/executor_projection.h b/src/execution/executor_projection.h
index 87554df..7311d7f 100644
--- a/src/execution/executor_projection.h
+++ b/src/execution/executor_projection.h
@@ -15,11 +15,6 @@ See the Mulan PSL v2 for more details. */
 #include "index/ix.h"
 #include "system/sm.h"
 
-struct Aggr_Col {
-    size_t sel_col_idx;   // 字段在上一层的idx
-    size_t col_idx;       // 字段在当前层的idx
-    AggrType type;
-};
 
 class ProjectionExecutor : public AbstractExecutor {
 private:
@@ -31,7 +26,8 @@ private:
     bool is_only_count_;                     // 只有count列
     bool is_exist_aggr_;                     // 是否存在聚合函数
     bool has_group_;
-    std::vector<Aggr_Col> aggr_cols_;         // 聚合函数列     
+    std::vector<Aggr_Col> aggr_cols_;        // 聚合函数列 
+    std::vector<size_t> idxs_;               // 非聚合函数列位置   
 
 public:
     ProjectionExecutor(std::unique_ptr<AbstractExecutor> prev, const std::vector<TabCol>& sel_cols, bool has_group = false) {
@@ -79,7 +75,9 @@ public:
                     .type = sel_col.type
                 };
                 aggr_cols_.push_back(aggr_col);
-            } 
+            } else {
+                idxs_.push_back(sel_idxs_.size() - 1);
+            }
         }
         len_ = curr_offset;
 
@@ -114,6 +112,23 @@ public:
 
         if (has_group_) {
             // 按照group by结果进行投影
+            std::unique_ptr<RmRecord> prev_rec = prev_->Next();
+            
+            // 遍历聚合函数列
+            size_t prev_offset = 0;
+            for (auto& aggr_col : aggr_cols_) {
+                // 取出对应字段
+                ColMeta col = cols_[aggr_col.col_idx];
+                memcpy(rec->data + col.offset, prev_rec->data + prev_offset, col.len);
+                prev_offset += col.len;
+            }
+            
+            // 寻找所有 groupby 字段的位置,并进行投影
+            for (size_t idx : idxs_) {
+                // 取出对应字段
+                ColMeta col = cols_[idx];
+                memcpy(rec->data + col.offset, prev_rec->data + prev_offset, col.len);
+            }
 
         } else if (is_exist_aggr_) {
             // 没有group by 但是存在聚合函数
diff --git a/test/4.1-more.txt b/test/4.1-more.txt
index dcea090..acfd91a 100644
--- a/test/4.1-more.txt
+++ b/test/4.1-more.txt
@@ -18,6 +18,7 @@ select COUNT(course) as course_num from grade;
 select COUNT(*) as row_num from grade;
 select SUM(score) as sum_score from grade where id = 1;
 select SUM(score) as sum_score, MIN(score) as min_score, MAX(score) as max_score from grade where id = 1;
+select id, count(*) as row_num, course from grade where id = 2;
 drop table grade;
 
 create table t ( id int , t_name char (3));
-- 
GitLab