1
0
Fork 0
httpserver/middleware/db/db.go

136 Zeilen
3.4 KiB
Go

package db
import (
"context"
"io/fs"
"sync"
"github.com/gin-gonic/gin"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jackc/tern/migrate"
"github.com/phuslu/log"
"go.sebtobie.de/httpserver/funcs"
"go.sebtobie.de/httpserver/middleware"
)
const (
	// ContextKey is the gin.Context key under which Middleware.Gin stores the
	// connection getter (the Middleware's GetConn method).
	ContextKey string = "db"
)
// ConnGet is the signature of a function that hands out a pooled connection
// for the database registered under name. Middleware.GetConn is the
// implementation in this package; it returns nil when no connection could be
// acquired.
type ConnGet func(name string) *pgxpool.Conn
// Compile-time checks that *Middleware implements middleware.Middleware and
// that its GetConn method has the ConnGet signature. A typed nil pointer is
// enough for both checks, so no Middleware is built or initialised at package
// init.
var (
	_ ConnGet               = (*Middleware)(nil).GetConn
	_ middleware.Middleware = (*Middleware)(nil)
)
// Middleware is a gin middleware that manages a set of named PostgreSQL
// connection pools. Pools are registered with AddDB, connections are handed
// out by GetConn, and Gin places GetConn into each request's gin.Context under
// ContextKey.
//
// Construct it with NewMiddleware; AddDB also initialises the pool map of a
// zero value. A Middleware contains a sync.Mutex and must not be copied after
// first use.
type Middleware struct {
	// databases maps a database name (e.g. "default") to its connection pool.
	databases map[string]*pgxpool.Pool
	// lock serializes access to databases.
	lock sync.Mutex
}
// NewMiddleware returns a Middleware whose set of database pools is empty and
// ready for AddDB.
func NewMiddleware() *Middleware {
	m := new(Middleware)
	m.databases = map[string]*pgxpool.Pool{}
	return m
}
// AddDB connects to the PostgreSQL database described by uri and registers the
// resulting pool under name.
//
// uri is passed unchanged to pgxpool.ParseConfig, and the pool's connections
// log through funcs.PGXLogger. If a pool is already registered under name, it
// is replaced and the old pool is closed afterwards. On a parse or connect
// error the error is logged and returned, and any existing pool registered
// under name is left untouched.
func (m *Middleware) AddDB(name, uri string) error {
	config, err := pgxpool.ParseConfig(uri)
	if err != nil {
		log.Error().Err(err).Str("name", name).Msg("Could not open the database")
		return err
	}
	config.ConnConfig.Logger = funcs.PGXLogger{}

	pool, err := pgxpool.ConnectConfig(context.TODO(), config)
	if err != nil {
		log.Error().Err(err).Str("name", name).Msg("Could not open the database")
		return err
	}

	m.lock.Lock()
	if m.databases == nil {
		// Support a zero-value Middleware that was not built by NewMiddleware.
		m.databases = make(map[string]*pgxpool.Pool)
	}
	old, found := m.databases[name]
	m.databases[name] = pool
	m.lock.Unlock()

	// Close the replaced pool only after releasing m.lock. Per the pgxpool
	// documentation, Pool.Close blocks until every acquired connection has been
	// returned. Holding the lock here would stall all GetConn calls, and would
	// deadlock a caller that still holds a connection from the old pool.
	if found {
		old.Close()
	}
	return nil
}
// lookupPool returns the pool registered under name, or nil if there is none.
// m.lock is held only for the map lookup, never while a connection is being
// acquired.
func (m *Middleware) lookupPool(name string) *pgxpool.Pool {
	m.lock.Lock()
	defer m.lock.Unlock()
	return m.databases[name]
}

// getconn acquires a connection from the pool registered under name. It
// returns nil without logging if no such pool exists. If Acquire fails, it
// logs the error and returns nil.
func (m *Middleware) getconn(name string) *pgxpool.Conn {
	db := m.lookupPool(name)
	if db == nil {
		return nil
	}
	conn, err := db.Acquire(context.TODO())
	if err != nil {
		log.Error().Err(err).Msgf("Could not get the connection from the pool %s", name)
		return nil
	}
	return conn
}

// GetConn returns a connection from the database registered under name. If
// that database is not registered, or no connection could be acquired from it,
// GetConn falls back to the database registered as "default". It returns nil
// when neither yields a connection. The caller owns the returned connection and
// is presumably expected to Release it (confirm at call sites).
//
// m.lock is not held while Acquire runs. Acquire may wait for a free
// connection, and holding the lock during that wait would stall every other
// GetConn and AddDB call.
func (m *Middleware) GetConn(name string) *pgxpool.Conn {
	if conn := m.getconn(name); conn != nil {
		return conn
	}
	return m.getconn("default")
}
// Gin is the gin entrypoint of the middleware. It stores the GetConn method
// value in the request's gin.Context under ContextKey so that handlers can
// acquire database connections.
//
// NOTE(review): the stored value's dynamic type is the unnamed
// func(string) *pgxpool.Conn, not ConnGet. A type assertion to ConnGet on the
// value returned by c.Get(ContextKey) would therefore fail; confirm what
// consumers assert.
func (m *Middleware) Gin(c *gin.Context) {
	getter := m.GetConn
	c.Set(ContextKey, getter)
}
// Teardown closes every registered database pool and unregisters it. After
// Teardown, GetConn returns nil instead of trying to acquire from a closed
// pool, and AddDB may register fresh pools.
func (m *Middleware) Teardown() {
	// Snapshot and reset the map under the lock. Ranging over m.databases
	// without the lock races with AddDB and GetConn.
	m.lock.Lock()
	pools := m.databases
	m.databases = make(map[string]*pgxpool.Pool)
	m.lock.Unlock()

	// Close outside the lock. Closing a pool may wait for its outstanding
	// connections to be released, which must not block other callers of m.
	for name, pool := range pools {
		log.Info().Msgf("Starting to close databasepool %s", name)
		pool.Close()
		log.Info().Msgf("Closed Databasepool %s", name)
	}
}
// logmigrations is installed by SetupMigrator as the migrator's OnStart hook
// and logs each migration as it starts. The migration's SQL text (fourth
// parameter) is intentionally ignored so that it is not written to the log.
// dir is interpolated into "%sward"; this assumes it is "up" or "down",
// giving "upward"/"downward" (verify against tern).
func logmigrations(version int32, name string, dir string, _ string) {
	log.Info().Int32("version", version).Str("name", name).Str("dir", dir).Msgf("Running Migration %s with version %d in %sward direction", name, version, dir)
}
// SetupMigrator creates a tern migrator that runs on connection, records
// applied versions in the "version" table, and reads migration files from the
// migrations filesystem. DisableTx is left false, so transactions are not
// disabled. logmigrations is installed as the OnStart hook, and prefix is
// stored in mig.Data["prefix"] (presumably for use by migration templates;
// confirm).
//
// The error from migrate.NewMigratorEx is logged and returned. In that case the
// returned migrator is nil rather than a partially initialised one.
func SetupMigrator(prefix string, connection *pgx.Conn, migrations fs.FS) (mig *migrate.Migrator, err error) {
	mig, err = migrate.NewMigratorEx(
		context.TODO(),
		connection,
		"version",
		&migrate.MigratorOptions{
			DisableTx:  false,
			MigratorFS: &iofsMigratorFS{fsys: migrations},
		},
	)
	if err != nil {
		log.Error().Err(err).Msg("Error while creating the migrator")
		return nil, err
	}
	mig.OnStart = logmigrations
	mig.Data["prefix"] = prefix
	return mig, nil
}