package db import ( "context" "io/fs" "sync" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/tern/migrate" "github.com/phuslu/log" "go.sebtobie.de/httpserver/funcs" "go.sebtobie.de/httpserver/middleware" ) // ContextKey is the key that is used in a gin.Context to get the Middleware const ContextKey string = "db" // ConnGet is an function that returns an connection instance. type ConnGet func(string) *pgxpool.Conn var _ ConnGet = NewMiddleware().GetConn var _ middleware.Middleware = &Middleware{} // Middleware return a handler that sets the db into the context of every request. // uri is an url in the form dbtype:connectargs type Middleware struct { databases map[string]*pgxpool.Pool lock sync.Mutex } //NewMiddleware return an initialized Middleware Object. func NewMiddleware() *Middleware { return &Middleware{ databases: make(map[string]*pgxpool.Pool), } } // AddDB adds an db connection to the middleware. func (m *Middleware) AddDB(name, uri string) (err error) { m.lock.Lock() if m.databases == nil { m.databases = map[string]*pgxpool.Pool{} } m.lock.Unlock() var ( db *pgxpool.Pool connobject *pgxpool.Config ) if err != nil { log.Error().Err(err).Msg("Could not open the database") return err } connobject, err = pgxpool.ParseConfig(uri) if err != nil { log.Error().Err(err).Msg("Could not open the database") return err } connobject.ConnConfig.Logger = funcs.PGXLogger{} db, err = pgxpool.ConnectConfig(context.TODO(), connobject) if err != nil { log.Error().Err(err).Msg("Could not open the database") return err } m.lock.Lock() defer m.lock.Unlock() if olddb, found := m.databases[name]; found { olddb.Close() } m.databases[name] = db return } func (m *Middleware) getconn(name string) *pgxpool.Conn { if db, found := m.databases[name]; found { conn, err := db.Acquire(context.TODO()) if err != nil { log.Error().Err(err).Msgf("Could not get the connection from the pool %s", name) return nil } return conn } return nil } // GetConn returns a connection from the specified database or if not found one of the default database. func (m *Middleware) GetConn(name string) *pgxpool.Conn { m.lock.Lock() defer m.lock.Unlock() if conn := m.getconn(name); conn != nil { return conn } conn := m.getconn("default") return conn } // Gin is the Entrypoint for Gin. func (m *Middleware) Gin(c *gin.Context) { c.Set(ContextKey, m.GetConn) } // Teardown closes the DBConnection func (m *Middleware) Teardown() { for name, pool := range m.databases { log.Info().Msgf("Starting to close databasepool %s", name) pool.Close() log.Info().Msgf("Closed Databasepool %s", name) } } func logmigrations(version int32, name string, dir string, sql string) { sql = "" log.Info().Int32("version", version).Str("name", name).Str("dir", dir).Msgf("Running Migration %s with version %d in %sward direction", name, version, dir) } // SetupMigrator sets up the migrator to migrate the database. func SetupMigrator(prefix string, connection *pgx.Conn, migrations fs.FS) (mig *migrate.Migrator, err error) { mig, err = migrate.NewMigratorEx( context.TODO(), connection, "version", &migrate.MigratorOptions{ DisableTx: false, MigratorFS: &iofsMigratorFS{fsys: migrations}, }, ) if err != nil { log.Error().Err(err).Msg("Error while creating the migrator") return } mig.OnStart = logmigrations mig.Data["prefix"] = prefix return }