@@ -49,12 +49,13 @@ type Migration struct {
4949 Down func (ctx context.Context , txn * sql.Tx ) error
5050}
5151
52- // Migrator
52+ // Migrator contains fields required to run migrations.
5353type Migrator struct {
5454 db * sql.DB
5555 migrations []Migration
5656 knownMigrations map [string ]struct {}
5757 mutex * sync.Mutex
58+ insertStmt * sql.Stmt
5859}
5960
6061// NewMigrator creates a new DB migrator.
@@ -82,42 +83,50 @@ func (m *Migrator) AddMigrations(migrations ...Migration) {
8283
8384// Up executes all migrations in order they were added.
8485func (m * Migrator ) Up (ctx context.Context ) error {
85- var (
86- err error
87- dendriteVersion = internal .VersionString ()
88- )
8986 // ensure there is a table for known migrations
9087 executedMigrations , err := m .ExecutedMigrations (ctx )
9188 if err != nil {
9289 return fmt .Errorf ("unable to create/get migrations: %w" , err )
9390 }
94-
91+ // ensure we close the insert statement, as it's not needed anymore
92+ defer m .close ()
9593 return WithTransaction (m .db , func (txn * sql.Tx ) error {
9694 for i := range m .migrations {
97- now := time .Now ().UTC ().Format (time .RFC3339 )
9895 migration := m .migrations [i ]
9996 // Skip migration if it was already executed
10097 if _ , ok := executedMigrations [migration .Version ]; ok {
10198 continue
10299 }
103100 logrus .Debugf ("Executing database migration '%s'" , migration .Version )
104- err = migration . Up ( ctx , txn )
105- if err != nil {
101+
102+ if err = migration . Up ( ctx , txn ); err != nil {
106103 return fmt .Errorf ("unable to execute migration '%s': %w" , migration .Version , err )
107104 }
108- _ , err = txn .ExecContext (ctx , insertVersionSQL ,
109- migration .Version ,
110- now ,
111- dendriteVersion ,
112- )
113- if err != nil {
105+ if err = m .insertMigration (ctx , txn , migration .Version ); err != nil {
114106 return fmt .Errorf ("unable to insert executed migrations: %w" , err )
115107 }
116108 }
117109 return nil
118110 })
119111}
120112
113+ func (m * Migrator ) insertMigration (ctx context.Context , txn * sql.Tx , migrationName string ) error {
114+ if m .insertStmt == nil {
115+ stmt , err := m .db .Prepare (insertVersionSQL )
116+ if err != nil {
117+ return fmt .Errorf ("unable to prepare insert statement: %w" , err )
118+ }
119+ m .insertStmt = stmt
120+ }
121+ stmt := TxStmtContext (ctx , txn , m .insertStmt )
122+ _ , err := stmt .ExecContext (ctx ,
123+ migrationName ,
124+ time .Now ().Format (time .RFC3339 ),
125+ internal .VersionString (),
126+ )
127+ return err
128+ }
129+
121130// ExecutedMigrations returns a map with already executed migrations in addition to creating the
122131// migrations table, if it doesn't exist.
123132func (m * Migrator ) ExecutedMigrations (ctx context.Context ) (map [string ]struct {}, error ) {
@@ -146,19 +155,20 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{},
146155// inserts a migration given their name to the database.
147156// This should only be used when manually inserting migrations.
148157func InsertMigration (ctx context.Context , db * sql.DB , migrationName string ) error {
149- _ , err := db .ExecContext (ctx , createDBMigrationsSQL )
158+ m := NewMigrator (db )
159+ defer m .close ()
160+ existingMigrations , err := m .ExecutedMigrations (ctx )
150161 if err != nil {
151- return fmt . Errorf ( "unable to create db_migrations: %w" , err )
162+ return err
152163 }
153- _ , err = db .ExecContext (ctx , insertVersionSQL ,
154- migrationName ,
155- time .Now ().Format (time .RFC3339 ),
156- internal .VersionString (),
157- )
158- // If the migration was already executed, we'll get a unique constraint error,
159- // return nil instead, to avoid unnecessary logging.
160- if IsUniqueConstraintViolationErr (err ) {
164+ if _ , ok := existingMigrations [migrationName ]; ok {
161165 return nil
162166 }
163- return err
167+ return m .insertMigration (ctx , nil , migrationName )
168+ }
169+
170+ func (m * Migrator ) close () {
171+ if m .insertStmt != nil {
172+ internal .CloseAndLogIfError (context .Background (), m .insertStmt , "unable to close insert statement" )
173+ }
164174}
0 commit comments