121 Zeilen
3.0 KiB
Go
121 Zeilen
3.0 KiB
Go
|
package db
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/gin-gonic/gin"
|
||
|
"github.com/jackc/pgx"
|
||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||
|
"github.com/jackc/tern/migrate"
|
||
|
"github.com/phuslu/log"
|
||
|
"go.sebtobie.de/httpserver/funcs"
|
||
|
"go.sebtobie.de/httpserver/middleware"
|
||
|
)
|
||
|
|
||
|
// ContextKey is the key that is used in a gin.Context to get the Middleware
|
||
|
const ContextKey string = "db"
|
||
|
|
||
|
var _ middleware.Middleware = &Middleware{}
|
||
|
|
||
|
// Middleware return a handler that sets the db into the context of every request.
|
||
|
// uri is an url in the form dbtype:connectargs
|
||
|
type Middleware struct {
|
||
|
databases map[string]*pgxpool.Pool
|
||
|
lock sync.Mutex
|
||
|
}
|
||
|
|
||
|
//NewMiddleware return an initialized Middleware Object.
|
||
|
func NewMiddleware() *Middleware {
|
||
|
return &Middleware{
|
||
|
databases: make(map[string]*pgxpool.Pool),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// AddDB adds an db connection to the middleware.
|
||
|
func (m *Middleware) AddDB(name, uri string) (err error) {
|
||
|
m.lock.Lock()
|
||
|
if m.databases == nil {
|
||
|
m.databases = map[string]*pgxpool.Pool{}
|
||
|
}
|
||
|
m.lock.Unlock()
|
||
|
var (
|
||
|
db *pgxpool.Pool
|
||
|
connobject *pgxpool.Config
|
||
|
)
|
||
|
if err != nil {
|
||
|
log.Error().Err(err).Msg("Could not open the database")
|
||
|
return err
|
||
|
}
|
||
|
connobject, err = pgxpool.ParseConfig(uri)
|
||
|
if err != nil {
|
||
|
log.Error().Err(err).Msg("Could not open the database")
|
||
|
return err
|
||
|
}
|
||
|
connobject.ConnConfig.Logger = funcs.PGXLogger{}
|
||
|
db, err = pgxpool.ConnectConfig(context.TODO(), connobject)
|
||
|
if err != nil {
|
||
|
log.Error().Err(err).Msg("Could not open the database")
|
||
|
return err
|
||
|
}
|
||
|
m.lock.Lock()
|
||
|
defer m.lock.Unlock()
|
||
|
if olddb, found := m.databases[name]; found {
|
||
|
olddb.Close()
|
||
|
}
|
||
|
m.databases[name] = db
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// GetDB returns a pgx.Conn Object or nil if the name is not found.
|
||
|
func (m *Middleware) GetDB(name string) *pgxpool.Conn {
|
||
|
m.lock.Lock()
|
||
|
defer m.lock.Unlock()
|
||
|
if db, found := m.databases[name]; found {
|
||
|
conn, err := db.Acquire(context.TODO())
|
||
|
if err != nil {
|
||
|
log.Error().Err(err).Msg("Could not get a connection from the pool")
|
||
|
return nil
|
||
|
}
|
||
|
return conn
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Gin is the Entrypoint for Gin.
|
||
|
func (m *Middleware) Gin(c *gin.Context) {
|
||
|
c.Set(ContextKey, m)
|
||
|
}
|
||
|
|
||
|
// Teardown closes the DBConnection
|
||
|
func (m *Middleware) Teardown() {
|
||
|
for name, pool := range m.databases {
|
||
|
log.Info().Msgf("Starting to close databasepool %s", name)
|
||
|
pool.Close()
|
||
|
log.Info().Msgf("Closed Databasepool %s", name)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func logmigrations(version int32, name string, dir string, sql string) {
|
||
|
log.Info().Int32("version", version).Str("name", name).Str("dir", dir).Msgf("Running Migration %s with version %i in %sward direction", name, version, dir)
|
||
|
}
|
||
|
|
||
|
// SetupMigrator sets up the migrator to migrate the database.
|
||
|
func SetupMigrator(prefix string, connection *pgx.Conn, migrations migrate.MigratorFS) (mig *migrate.Migrator, err error) {
|
||
|
mig, err = migrate.NewMigratorEx(
|
||
|
context.TODO(),
|
||
|
connection,
|
||
|
"version",
|
||
|
&migrate.MigratorOptions{
|
||
|
DisableTX: false,
|
||
|
MigratorFS: migrations,
|
||
|
},
|
||
|
)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
mig.OnStart = logmigrations
|
||
|
mig.Data["prefix"] = prefix
|
||
|
return
|
||
|
}
|