// Package db is an middleware that manages multiple database pools and provides applications with an way to access the database package db import ( "context" "errors" "fmt" "io/fs" "sync" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/tern/v2/migrate" "github.com/rs/zerolog/log" "go.sebtobie.de/httpserver/middleware" uuid "github.com/jackc/pgx-gofrs-uuid" ) // ContextKey is the key that is used in a gin.Context to get the Middleware const ContextKey string = "db" // ConnGet is an function that returns an connection instance. type ConnGet func(string) *pgxpool.Conn var _ ConnGet = NewMiddleware().GetConn var _ middleware.Middleware = &Middleware{} var _ middleware.PostSetupMiddleware = &Middleware{} // GetConnection is an simple helper function that returns an connection to the db func GetConnection(c *gin.Context, db string) (*pgxpool.Conn, error) { if co, ok := c.Get(ContextKey); ok { if cg, ok := co.(ConnGet); ok { return cg(db), nil } return nil, fmt.Errorf("Failed to convert the method. %T != ConnGet", co) } return nil, errors.New("No db.Middleware set up. ") } // Middleware return a handler that sets the db into the context of every request. // uri is an url in the form dbtype:connectargs type Middleware struct { databases map[string]*pgxpool.Pool lock sync.Mutex } // NewMiddleware return an initialized Middleware Object. func NewMiddleware() *Middleware { return &Middleware{ databases: make(map[string]*pgxpool.Pool), } } // AddDB adds an db connection to the middleware. 
func (m *Middleware) AddDB(name, uri string) (err error) { m.lock.Lock() if m.databases == nil { m.databases = map[string]*pgxpool.Pool{} } m.lock.Unlock() var ( db *pgxpool.Pool ) db, err = pgxpool.New(context.TODO(), uri) if err != nil { log.Error().Err(err).Msg("Could not open the database") return err } db.Config().AfterConnect = func(_ context.Context, c *pgx.Conn) error { uuid.Register(c.TypeMap()) return nil } m.lock.Lock() defer m.lock.Unlock() if olddb, found := m.databases[name]; found { olddb.Close() } m.databases[name] = db return } func (m *Middleware) getconn(name string) *pgxpool.Conn { if db, found := m.databases[name]; found { conn, err := db.Acquire(context.TODO()) if err != nil { log.Error().Err(err).Msgf("Could not get the connection from the pool %s", name) return nil } return conn } return nil } // GetConn returns a connection from the specified database or if not found one of the default database. func (m *Middleware) GetConn(name string) *pgxpool.Conn { m.lock.Lock() defer m.lock.Unlock() if conn := m.getconn(name); conn != nil { return conn } conn := m.getconn("default") return conn } // Gin is the Entrypoint for Gin. 
func (m *Middleware) Gin(c *gin.Context) {
	c.Set(ContextKey, ConnGet(m.GetConn))
}

// Setup adds the connections from the config file to the middleware. Each
// key is a database name and each value must be a connection string. Entries
// that are not strings, or that AddDB fails to open, are logged and skipped.
func (m *Middleware) Setup(mc middleware.Config) {
	for key, value := range mc {
		dsn, ok := value.(string)
		if !ok {
			// An unchecked assertion here would panic the whole server on a
			// single malformed config entry.
			log.Error().Str("database", key).Msgf("Expected a connection string, got %T", value)
			continue
		}
		if err := m.AddDB(key, dsn); err != nil {
			log.Error().Err(err).Msg("Failed to parse the config")
		}
	}
}

// Defaults returns a default config with a single "default" connection.
func (*Middleware) Defaults() middleware.Config {
	return map[string]any{
		"default": "host=/run/postgresql port=5432 dbname=httpserver",
	}
}

// Teardown closes every registered database pool.
func (m *Middleware) Teardown() {
	// Snapshot under the lock, then close without it. Pool.Close blocks until
	// all connections are released, and GetConn needs the lock to look pools up.
	m.lock.Lock()
	pools := make(map[string]*pgxpool.Pool, len(m.databases))
	for name, pool := range m.databases {
		pools[name] = pool
	}
	m.lock.Unlock()

	for name, pool := range pools {
		log.Info().Msgf("Starting to close databasepool %s", name)
		pool.Close()
		log.Info().Msgf("Closed Databasepool %s", name)
	}
}

// PostSetup runs the database migrations of every site that implements
// MigrationSite. Sites that also implement PoolSite are then handed their
// pool. It stops at the first error. Referencing an unregistered database is
// also an error.
func (m *Middleware) PostSetup(sites []any) error {
	m.lock.Lock()
	defer m.lock.Unlock()
	for _, s := range sites {
		site, ok := s.(MigrationSite)
		if !ok {
			continue
		}
		db, ok := m.databases[site.Database()]
		if !ok {
			return fmt.Errorf("Failed to get the database. The Databasepool %s does not exist", site.Database())
		}
		if err := migrateSite(db, site); err != nil {
			return err
		}
		if poolsite, ok := s.(PoolSite); ok {
			poolsite.Pool(db)
		}
	}
	return nil
}

// migrateSite runs the migrations of site on a connection acquired from db.
// The connection is released before returning. A deferred release inside the
// PostSetup loop would keep one connection per site and could exhaust the pool.
func migrateSite(db *pgxpool.Pool, site MigrationSite) error {
	conn, err := db.Acquire(context.TODO())
	if err != nil {
		return err
	}
	defer conn.Release()

	mig, err := site.Migrations(conn.Conn())
	if err != nil {
		// Without this check, mig may be nil and Migrate would panic.
		return err
	}
	return mig.Migrate(context.TODO())
}

// logmigrations is installed as the migrator's OnStart hook (see
// SetupMigrator). It logs each migration as it begins.
func logmigrations(version int32, name, dir string, sql string) {
	log.Info().Int32("version", version).Str("name", name).Str("dir", dir).Msgf("Running Migration %s with version %d in %sward direction", name, version, dir)
}

// MigrationSite is an interface for sites that use a database.
// It lets sites provide database migrations.
type MigrationSite interface { Database() string Migrations(*pgx.Conn) (*migrate.Migrator, error) } // PoolSite is an interface for site that need access to the pool outside of requests type PoolSite interface { MigrationSite Pool(*pgxpool.Pool) } // SetupMigrator sets up the migrator to migrate the database. func SetupMigrator(prefix string, connection *pgx.Conn, migrations fs.FS) (mig *migrate.Migrator, err error) { mig, err = migrate.NewMigratorEx( context.TODO(), connection, prefix+"version", &migrate.MigratorOptions{ DisableTx: false, }, ) if err != nil { log.Error().Err(err).Msg("Error while creating the migrator") return } mig.OnStart = logmigrations mig.Data["prefix"] = prefix err = mig.LoadMigrations(migrations) if err != nil { log.Error().Err(err).Msg("Error while loading migrations") return } log.Trace().Interface("migrations", mig.Migrations).Interface("data", mig.Data).Err(err).Send() return }