[Commits] f225b4a2bf4: MDEV-14564: Support FOR loop in stored aggregate functions

Varun varunraiko1803 at gmail.com
Fri Dec 29 11:06:36 EET 2017


revision-id: f225b4a2bf45757e4e00c33e3e10487dd6b059a4 (mariadb-10.3.0-341-gf225b4a2bf4)
parent(s): 7a66e0ab8f52f3bd32850463daa05f9a2401e6b1
author: Varun Gupta
committer: Varun Gupta
timestamp: 2017-12-29 14:36:15 +0530
message:

MDEV-14564: Support FOR loop in stored aggregate functions

Added for loop in custom aggregate function. A special defination for FOR loop is
required for this and that is FOR GROUP NEXT ROW. This would fetch rows one by one
and when all rows are fetched we would exit from the loop.

---
 mysql-test/r/custom_aggregate_functions.result | 26 +++++++++++++++++
 mysql-test/t/custom_aggregate_functions.test   | 22 +++++++++++++++
 sql/sp_head.cc                                 | 18 ++++++++----
 sql/sp_head.h                                  |  7 +++--
 sql/sp_rcontext.cc                             |  2 +-
 sql/sp_rcontext.h                              |  2 ++
 sql/sql_lex.cc                                 | 39 ++++++++++++++++++++++++++
 sql/sql_lex.h                                  | 13 +++++++--
 sql/sql_yacc.yy                                |  6 +++-
 9 files changed, 124 insertions(+), 11 deletions(-)

diff --git a/mysql-test/r/custom_aggregate_functions.result b/mysql-test/r/custom_aggregate_functions.result
index c2fdc271140..41e4c776185 100644
--- a/mysql-test/r/custom_aggregate_functions.result
+++ b/mysql-test/r/custom_aggregate_functions.result
@@ -936,3 +936,29 @@ drop function f4;
 drop function f5;
 drop function f6;
 drop function f7;
+#
+# MDEV-14564: Support FOR loop in stored aggregate functions
+#
+CREATE aggregate FUNCTION f1 (a INT) RETURNS INT
+BEGIN
+DECLARE total INT DEFAULT 0;
+FOR GROUP NEXT ROW DO
+SET total= total + a;
+END FOR;
+return total;
+END /
+show function code f1;
+Pos	Instruction
+0	set total at 1 0
+1	jump_if_not 5(5) 1
+2	
+3	set total at 1 total at 1 + a at 0
+4	jump 1
+5	freturn int total at 1
+create table t1( a int);
+insert into t1 values (1),(2),(3),(4);
+select f1(a) from t1;
+f1(a)
+10
+drop table t1;
+drop function f1;
diff --git a/mysql-test/t/custom_aggregate_functions.test b/mysql-test/t/custom_aggregate_functions.test
index 20fcc35f39f..61828c33e9e 100644
--- a/mysql-test/t/custom_aggregate_functions.test
+++ b/mysql-test/t/custom_aggregate_functions.test
@@ -771,3 +771,25 @@ drop function f4;
 drop function f5;
 drop function f6;
 drop function f7;
+
+--echo #
+--echo # MDEV-14564: Support FOR loop in stored aggregate functions
+--echo #
+
+DELIMITER /;
+CREATE aggregate FUNCTION f1 (a INT) RETURNS INT
+BEGIN
+ DECLARE total INT DEFAULT 0;
+ FOR GROUP NEXT ROW DO
+  SET total= total + a;
+ END FOR;
+ return total;
+END /
+
+delimiter ;/
+show function code f1;
+create table t1( a int);
+insert into t1 values (1),(2),(3),(4);
+select f1(a) from t1;
+drop table t1;
+drop function f1;
diff --git a/sql/sp_head.cc b/sql/sp_head.cc
index bcdef10506f..df5af9687fd 100644
--- a/sql/sp_head.cc
+++ b/sql/sp_head.cc
@@ -3715,7 +3715,7 @@ sp_instr_jump_if_not::exec_core(THD *thd, uint *nextp)
   else
   {
     res= 0;
-    if (! it->val_bool())
+    if (! it->val_bool() || thd->spcont->forced_error)
       *nextp = m_dest;
     else
       *nextp = m_ip+1;
@@ -4246,10 +4246,18 @@ sp_instr_agg_cfetch::execute(THD *thd, uint *nextp)
     thd->spcont->pause_state= FALSE;
     if (thd->server_status == SERVER_STATUS_LAST_ROW_SENT)
     {
-      my_message(ER_SP_FETCH_NO_DATA,
-                 ER_THD(thd, ER_SP_FETCH_NO_DATA), MYF(0));
-      res= -1;
-      thd->spcont->quit_func= TRUE;
+      if (m_dest)
+      {
+        thd->spcont->forced_error= TRUE;
+        *nextp= m_dest;
+      }
+      else
+      {
+        my_message(ER_SP_FETCH_NO_DATA,
+                   ER_THD(thd, ER_SP_FETCH_NO_DATA), MYF(0));
+        res= -1;
+        thd->spcont->quit_func= TRUE;
+      }
     }
     else
       *nextp= m_ip + 1;
diff --git a/sql/sp_head.h b/sql/sp_head.h
index 7e477544958..4467622e612 100644
--- a/sql/sp_head.h
+++ b/sql/sp_head.h
@@ -1843,8 +1843,8 @@ class sp_instr_agg_cfetch : public sp_instr
 
 public:
 
-  sp_instr_agg_cfetch(uint ip, sp_pcontext *ctx)
-    : sp_instr(ip, ctx){}
+  sp_instr_agg_cfetch(uint ip, sp_pcontext *ctx, uint dest)
+    : sp_instr(ip, ctx), m_dest(dest){}
 
   virtual ~sp_instr_agg_cfetch()
   {}
@@ -1852,6 +1852,9 @@ class sp_instr_agg_cfetch : public sp_instr
   virtual int execute(THD *thd, uint *nextp);
 
   virtual void print(String *str){};
+
+private:
+  uint m_dest;
 }; // class sp_instr_agg_cfetch : public sp_instr
 
 
diff --git a/sql/sp_rcontext.cc b/sql/sp_rcontext.cc
index 740941937e8..633bc9706ef 100644
--- a/sql/sp_rcontext.cc
+++ b/sql/sp_rcontext.cc
@@ -41,7 +41,7 @@ sp_rcontext::sp_rcontext(const sp_head *owner,
                          bool in_sub_stmt)
   :end_partial_result_set(false),
    pause_state(false), quit_func(false), instr_ptr(0),
-   m_sp(owner),
+   forced_error(false), m_sp(owner),
    m_root_parsing_ctx(root_parsing_ctx),
    m_var_table(NULL),
    m_return_value_fld(return_value_fld),
diff --git a/sql/sp_rcontext.h b/sql/sp_rcontext.h
index 0999271ebde..a26bbc6276e 100644
--- a/sql/sp_rcontext.h
+++ b/sql/sp_rcontext.h
@@ -181,6 +181,8 @@ class sp_rcontext : public Sql_alloc
   bool pause_state;
   bool quit_func;
   uint instr_ptr;
+  // Added for exiting for loops for custom aggregate functions
+  bool forced_error;
 
   /// The stored program for which this runtime context is created. Used for
   /// checking if correct runtime context is used for variable handling.
diff --git a/sql/sql_lex.cc b/sql/sql_lex.cc
index b40a44b7541..bdeb2db1374 100644
--- a/sql/sql_lex.cc
+++ b/sql/sql_lex.cc
@@ -5684,6 +5684,22 @@ bool LEX::sp_for_loop_condition(THD *thd, const Lex_for_loop_st &loop)
   return !expr || sp_while_loop_expression(thd, expr);
 }
 
+bool LEX::sp_for_loop_condition_agg_func(THD *thd, const Lex_for_loop_st &loop)
+{
+  Item *expr= (Item*) new (thd->mem_root) Item_bool(thd, "true", 1);
+  uint dest= sphead->instructions();
+  bool res= (!expr || sp_while_loop_expression(thd, expr));
+  return res || add_agg_fetch_instructions(dest);
+}
+
+bool LEX::add_agg_fetch_instructions(uint dest)
+{
+  sphead->m_flags|= sp_head::HAS_AGGREGATE_INSTR;
+  sp_instr_agg_cfetch *instr= new (thd->mem_root)
+                              sp_instr_agg_cfetch(sphead->instructions(),
+                                                  spcont, dest);
+  return (instr == NULL || sphead->add_instr(instr));
+}
 
 /**
   Generate the FOR LOOP condition code in its own lex
@@ -5698,6 +5714,15 @@ bool LEX::sp_for_loop_intrange_condition_test(THD *thd,
   return thd->lex->sphead->restore_lex(thd);
 }
 
+bool LEX::sp_for_loop_agg_func_condition_test(THD *thd,
+                                              const Lex_for_loop_st &loop)
+{
+  spcont->set_for_loop(loop);
+  sphead->reset_lex(thd);
+  if (thd->lex->sp_for_loop_condition_agg_func(thd, loop))
+    return true;
+  return thd->lex->sphead->restore_lex(thd);
+}
 
 bool LEX::sp_for_loop_cursor_condition_test(THD *thd,
                                             const Lex_for_loop_st &loop)
@@ -5734,6 +5759,14 @@ bool LEX::sp_for_loop_intrange_declarations(THD *thd, Lex_for_loop_st *loop,
   return false;
 }
 
+bool LEX::sp_for_loop_agg_func_declarations(THD *thd, Lex_for_loop_st *loop)
+{
+  loop->m_index= NULL;
+  loop->m_upper_bound= NULL;
+  loop->m_direction= 1;
+  loop->m_implicit_cursor= 0;
+  return false;
+}
 
 bool LEX::sp_for_loop_cursor_declarations(THD *thd,
                                           Lex_for_loop_st *loop,
@@ -5836,6 +5869,12 @@ bool LEX::sp_for_loop_intrange_finalize(THD *thd, const Lex_for_loop_st &loop)
   return sp_while_loop_finalize(thd);
 }
 
+bool LEX::sp_for_loop_agg_func_finalize(THD *thd, const Lex_for_loop_st &loop)
+{
+  // Generate a jump to the beginning of the loop
+  DBUG_ASSERT(this == thd->lex);
+  return sp_while_loop_finalize(thd);
+}
 
 bool LEX::sp_for_loop_cursor_finalize(THD *thd, const Lex_for_loop_st &loop)
 {
diff --git a/sql/sql_lex.h b/sql/sql_lex.h
index fac06233592..81ba6f384ae 100644
--- a/sql/sql_lex.h
+++ b/sql/sql_lex.h
@@ -3479,6 +3479,13 @@ struct LEX: public Query_tables_list
   bool sp_for_loop_intrange_condition_test(THD *thd, const Lex_for_loop_st &loop);
   bool sp_for_loop_intrange_finalize(THD *thd, const Lex_for_loop_st &loop);
 
+  /* special for loops (FOR GROUP NEXT ROW) for aggregate functions */
+  bool add_agg_fetch_instructions(uint dest);
+  bool sp_for_loop_condition_agg_func(THD *thd, const Lex_for_loop_st &loop);
+  bool sp_for_loop_agg_func_condition_test(THD *thd, const Lex_for_loop_st &loop);
+  bool sp_for_loop_agg_func_finalize(THD *thd, const Lex_for_loop_st &loop);
+  bool sp_for_loop_agg_func_declarations(THD *thd, Lex_for_loop_st *loop);
+
   /* Cursor FOR LOOP methods */
   bool sp_for_loop_cursor_declarations(THD *thd, Lex_for_loop_st *loop,
                                        const LEX_CSTRING *index,
@@ -3537,7 +3544,8 @@ struct LEX: public Query_tables_list
   bool sp_for_loop_condition_test(THD *thd, const Lex_for_loop_st &loop)
   {
     return loop.is_for_loop_cursor() ?
-           sp_for_loop_cursor_condition_test(thd, loop) :
+           (loop.m_index ? sp_for_loop_cursor_condition_test(thd, loop)
+                         : sp_for_loop_agg_func_condition_test(thd, loop)) :
            sp_for_loop_intrange_condition_test(thd, loop);
   }
 
@@ -3552,7 +3560,8 @@ struct LEX: public Query_tables_list
   bool sp_for_loop_finalize(THD *thd, const Lex_for_loop_st &loop)
   {
     return loop.is_for_loop_cursor() ?
-           sp_for_loop_cursor_finalize(thd, loop) :
+           (loop.m_index ? sp_for_loop_cursor_finalize(thd, loop) :
+                           sp_for_loop_agg_func_finalize(thd, loop)) :
            sp_for_loop_intrange_finalize(thd, loop);
   }
   /* End of FOR LOOP methods */
diff --git a/sql/sql_yacc.yy b/sql/sql_yacc.yy
index 158c09dded1..519d53bec4a 100644
--- a/sql/sql_yacc.yy
+++ b/sql/sql_yacc.yy
@@ -4052,7 +4052,7 @@ sp_proc_stmt_fetch:
             lex->sphead->m_flags|= sp_head::HAS_AGGREGATE_INSTR;
             sp_instr_agg_cfetch *i=
               new (thd->mem_root) sp_instr_agg_cfetch(sp->instructions(),
-                                                      lex->spcont);
+                                                      lex->spcont,0);
             if (i == NULL ||
                 sp->add_instr(i))
               MYSQL_YYABORT;
@@ -4385,6 +4385,10 @@ sp_for_loop_index_and_bounds:
             if (Lex->sp_for_loop_declarations(thd, &$$, &$1, $2))
               MYSQL_YYABORT;
           }
+          | GROUP_SYM NEXT_SYM ROW_SYM
+          {
+            Lex->sp_for_loop_agg_func_declarations(thd, &$$);
+          }
         ;
 
 sp_for_loop_bounds:


More information about the commits mailing list