@@ -14,10 +14,13 @@ import (
1414
1515 "github.com/github/gh-ost/go/base"
1616 "github.com/github/gh-ost/go/binlog"
17- "github.com/github/gh-ost/go/mysql"
1817 "github.com/github/gh-ost/go/sql"
1918
20- "github.com/openark/golib/log"
19+ "context"
20+ "database/sql/driver"
21+
22+ "github.com/github/gh-ost/go/mysql"
23+ drivermysql "github.com/go-sql-driver/mysql"
2124 "github.com/openark/golib/sqlutils"
2225)
2326
@@ -1207,13 +1210,19 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
12071210// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
12081211func (this * Applier ) ApplyDMLEventQueries (dmlEvents [](* binlog.BinlogDMLEvent )) error {
12091212 var totalDelta int64
1213+ ctx := context .TODO ()
12101214
12111215 err := func () error {
1212- tx , err := this .db .Begin ( )
1216+ conn , err := this .db .Conn ( ctx )
12131217 if err != nil {
12141218 return err
12151219 }
1220+ defer conn .Close ()
12161221
1222+ tx , err := conn .BeginTx (ctx , nil )
1223+ if err != nil {
1224+ return err
1225+ }
12171226 rollback := func (err error ) error {
12181227 tx .Rollback ()
12191228 return err
@@ -1225,34 +1234,49 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
12251234 if _ , err := tx .Exec (sessionQuery ); err != nil {
12261235 return rollback (err )
12271236 }
1228- multiArgs := []interface {}{}
1237+ rowDeltas := make ([]int64 , 0 , len (dmlEvents ))
1238+ multiArgs := []driver.NamedValue {}
12291239 var multiQueryBuilder strings.Builder
12301240 for _ , dmlEvent := range dmlEvents {
12311241 for _ , buildResult := range this .buildDMLEventQuery (dmlEvent ) {
12321242 if buildResult .err != nil {
1233- return buildResult .err
1243+ return rollback ( buildResult .err )
12341244 }
1235- multiArgs = append (multiArgs , buildResult .args ... )
1245+ for _ , arg := range buildResult .args {
1246+ multiArgs = append (multiArgs , driver.NamedValue {Value : driver .Value (arg )})
1247+ }
1248+ rowDeltas = append (rowDeltas , buildResult .rowsDelta )
12361249 multiQueryBuilder .WriteString (buildResult .query )
12371250 multiQueryBuilder .WriteString (";\n " )
12381251 }
12391252 }
1240- // TODO: get rows affected from each query in multi statement
1241- log .Warningf ("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics" , err )
1242- _ , err = tx .Exec (multiQueryBuilder .String (), multiArgs ... )
1243- if err != nil {
1244- err = fmt .Errorf ("%w; query=%s; args=%+v" , err , multiQueryBuilder .String (), multiArgs )
1245- return rollback (err )
1246- }
1247- // rowsAffected, err := result.RowsAffected()
1248- // if err != nil {
1249- // log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
1250- // rowsAffected = 1
1251- // }
1252- // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1253- // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1254- // totalDelta += buildResult.rowsDelta * rowsAffected
12551253
1254+ //this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs)
1255+ execErr := conn .Raw (func (driverConn any ) error {
1256+ ex , ok := driverConn .(driver.ExecerContext )
1257+ if ! ok {
1258+ return fmt .Errorf ("could not cast driverConn to ExecerContext" )
1259+ }
1260+ res , err := ex .ExecContext (ctx , multiQueryBuilder .String (), multiArgs )
1261+ if err != nil {
1262+ err = fmt .Errorf ("%w; query=%s; args=%+v" , err , multiQueryBuilder .String (), multiArgs )
1263+ this .migrationContext .Log .Errorf ("Error exec: %+v" , err )
1264+ return err
1265+ }
1266+ mysqlRes , ok := res .(drivermysql.Result )
1267+ if ! ok {
1268+ return fmt .Errorf ("Could not cast %+v to mysql.Result" , res )
1269+ }
1270+ // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1271+ // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1272+ for i , rowsAffected := range mysqlRes .AllRowsAffected () {
1273+ totalDelta += rowDeltas [i ] * rowsAffected
1274+ }
1275+ return nil
1276+ })
1277+ if execErr != nil {
1278+ return rollback (execErr )
1279+ }
12561280 if err := tx .Commit (); err != nil {
12571281 return err
12581282 }
0 commit comments