diff --git a/go.mod b/go.mod index d0ceb1c..1cc6a5d 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/mattn/go-isatty v0.0.14 // indirect github.com/mitchellh/copystructure v1.1.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/pelletier/go-toml v1.9.4 github.com/phuslu/log v1.0.75 github.com/pkg/errors v0.9.1 // indirect github.com/shopspring/decimal v1.2.0 // indirect diff --git a/go.sum b/go.sum index 7d299c4..1cd88fe 100644 --- a/go.sum +++ b/go.sum @@ -221,8 +221,6 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= -github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/phuslu/log v1.0.75 h1:2Qcqgwo1sOsvj7QIuclIS92hmWxIISI2+XskYM1Nw2A= github.com/phuslu/log v1.0.75/go.mod h1:kzJN3LRifrepxThMjufQwS7S35yFAB+jAV1qgA7eBW4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/http.go b/http.go index 44c5678..10bcf5d 100644 --- a/http.go +++ b/http.go @@ -9,13 +9,18 @@ import ( "sync" "github.com/gin-gonic/gin" - "github.com/pelletier/go-toml" "github.com/phuslu/log" "go.sebtobie.de/httpserver/auth" "go.sebtobie.de/httpserver/menus" "go.sebtobie.de/httpserver/templates" ) +func init() { + gin.DebugPrintRouteFunc = func(httpMethod, absolutePath, handlerName string, nuHandlers int) { + log.Trace().Msgf("%-4s(%02d): %-20s %s", httpMethod, nuHandlers, absolutePath, handlerName) + } +} + // Config that is used to map the toml config to the settings that are used. type Config struct { Addr []string @@ -23,6 +28,7 @@ type Config struct { TLSconfig *tls.Config Certfile string Keyfile string + Sites map[string]SiteConfig } // MarshalObject adds the information over the object to the *log.Entry @@ -38,21 +44,36 @@ var _ log.ObjectMarshaler = &Config{} // Server is an wrapper for the *http.Server and *gin.Engine type Server struct { - http *http.Server - conf *Config - mrouter map[string]*gin.Engine - config *toml.Tree - sites []Site + http *http.Server + Conf *Config + mrouter map[string]*gin.Engine + sites []struct { + config SiteConfig + site Site + } menu []menus.Menu template template.Template NotFoundHandler http.Handler routines sync.WaitGroup } -func init() { - gin.DebugPrintRouteFunc = func(httpMethod, absolutePath, handlerName string, nuHandlers int) { - log.Trace().Msgf("%-4s(%02d): %-20s %s", httpMethod, nuHandlers, absolutePath, handlerName) +// CreateServer creates an server that can be run in a coroutine. +func CreateServer() *Server { + log.Info().Msg("Redirect logging output to phuslu/log") + gin.DefaultErrorWriter = log.DefaultLogger.Std(log.ErrorLevel, log.Context{}, "GIN", 0).Writer() + gin.DefaultWriter = log.DefaultLogger.Std(log.DebugLevel, log.Context{}, "GIN", 0).Writer() + log.Info().Msg("Creating HTTP-Server") + var server = &Server{ + Conf: &Config{}, + mrouter: map[string]*gin.Engine{}, } + server.http = &http.Server{ + ErrorLog: log.DefaultLogger.Std(log.ErrorLevel, log.Context{}, "", 0), + Handler: http.HandlerFunc(server.DomainRouter), + } + server.NotFoundHandler = http.NotFoundHandler() + server.menu = []menus.Menu{} + return server } // runPort runs a listener on the port. his enables th server to serve more than a address. @@ -68,7 +89,7 @@ func (s *Server) runPort(address string, tls bool) { log.Error().Err(err).Msgf("failed to open socket on %s", address) } if tls { - err = s.http.ServeTLS(socket, s.conf.Certfile, s.conf.Keyfile) + err = s.http.ServeTLS(socket, s.Conf.Certfile, s.Conf.Keyfile) } else { err = s.http.Serve(socket) } @@ -82,14 +103,15 @@ func (s *Server) runPort(address string, tls bool) { // it blocks until all ports are closed. func (s *Server) StartServer() { log.Info().Msg("Starting server") + s.http.TLSConfig = s.Conf.TLSconfig var err error - if s.conf.Certfile != "" && s.conf.Keyfile != "" { - for _, addr := range s.conf.TLSAddr { + if s.Conf.Certfile != "" && s.Conf.Keyfile != "" { + for _, addr := range s.Conf.TLSAddr { s.routines.Add(1) go s.runPort(addr, true) } } - for _, addr := range s.conf.Addr { + for _, addr := range s.Conf.Addr { go s.runPort(addr, false) } if err != http.ErrServerClosed { @@ -127,32 +149,6 @@ func (s *Server) DomainRouter(w http.ResponseWriter, r *http.Request) { s.NotFoundHandler.ServeHTTP(w, r) } -// CreateServer creates an server that can be run in a coroutine. -func CreateServer(config *toml.Tree) *Server { - log.Info().Msg("Redirect logging output to phuslu/log") - gin.DefaultErrorWriter = log.DefaultLogger.Std(log.ErrorLevel, log.Context{}, "GIN", 0).Writer() - gin.DefaultWriter = log.DefaultLogger.Std(log.DebugLevel, log.Context{}, "GIN", 0).Writer() - log.Info().Msg("Creating HTTP-Server") - var server = &Server{ - conf: &Config{ - Addr: []string{"127.0.0.1:8080"}, - }, - mrouter: map[string]*gin.Engine{}, - } - if err := config.Unmarshal(server.conf); err != nil { - log.Error().Msg("Problem mapping config to Configstruct") - } - log.Debug().EmbedObject(server.conf).Msg("Config") - server.http = &http.Server{ - ErrorLog: log.DefaultLogger.Std(log.ErrorLevel, log.Context{}, "", 0), - Handler: http.HandlerFunc(server.DomainRouter), - TLSConfig: server.conf.TLSconfig, - } - server.NotFoundHandler = http.NotFoundHandler() - server.menu = []menus.Menu{} - return server -} - // Use installs the middleware into the router. // The Middleware must be able to detect multiple calls byy itself. Deduplication is not performed. func (s *Server) Use(m ...gin.HandlerFunc) { @@ -164,33 +160,32 @@ func (s *Server) Use(m ...gin.HandlerFunc) { // Stop Shuts the Server down func (s *Server) Stop(ctx context.Context) { log.Info().Err(s.http.Shutdown(ctx)).Msg("Server Shut down.") - for _, site := range s.sites { - site.Teardown() + for _, s := range s.sites { + s.site.Teardown() } } -// Site is an Interface to abstract the modularized group of pages. -// The Middleware must be able to detect multiple calls byy itself. Deduplication is not performed. -type Site interface { - Init(*gin.RouterGroup) - Teardown() -} - func (s *Server) menus() []menus.Menu { return s.menu } // RegisterSite adds an site to the engine as its own grouo -func (s *Server) RegisterSite(domain, path string, site Site) { +func (s *Server) RegisterSite(cfg string, site Site) { var router *gin.Engine var found bool - if router, found = s.mrouter[domain]; !found { + var config = s.Conf.Sites[cfg] + if err := site.Setup(config); err != nil { + log.Error().Err(err).Msg("Site failed to load the config") + return + } + + if router, found = s.mrouter[config.Domain()]; !found { var authhf auth.AuthenticationHandler router = gin.New() mw := []gin.HandlerFunc{ func(c *gin.Context) { c.Set(Menus, s.menus) - c.Set(Domain, domain) + c.Set(Domain, config.Domain()) }, } if authhf, found = site.(auth.AuthenticationHandler); !found { @@ -198,18 +193,24 @@ func (s *Server) RegisterSite(domain, path string, site Site) { } mw = append(mw, func(c *gin.Context) { authhf.Account(c) }) router.Use(mw...) - s.mrouter[domain] = router + s.mrouter[config.Domain()] = router } - site.Init(router.Group(path)) - s.sites = append(s.sites, site) + site.Init(router.Group(config.Path())) + s.sites = append(s.sites, struct { + config SiteConfig + site Site + }{ + config: config, + site: site, + }) if ms, ok := site.(menus.MenuSite); ok { - menus := ms.Menu(domain) + menus := ms.Menu(config.Domain()) log.Debug().Msgf("%d menus are added", len(menus)) s.menu = append(s.menu, menus...) } if ts, ok := site.(templates.TemplateSite); ok { templates := ts.Templates() - log.Debug().Msgf("Templates for %s%s are added", domain, path) + log.Debug().Msgf("Templates for %s%s are added", config.Domain(), config.Path()) s.template.AddParseTree(templates.Name(), templates.Tree) s.template.Funcs(ts.Funcs()) } diff --git a/modules/saml/funcs.go b/modules/saml/funcs.go index d8b55cc..f5b8324 100644 --- a/modules/saml/funcs.go +++ b/modules/saml/funcs.go @@ -16,7 +16,7 @@ import ( func initcert(file string, verify func(interface{}) bool) (key interface{}, err error) { var blocks []*pem.Block if file == "" { - err = errors.New("SPPrivatekey empty") + err = fmt.Errorf("File %s is empty", file) return } blocks, err = loadcerts(file) @@ -87,7 +87,9 @@ func attributeStatementstomap(a []saml.AttributeStatement) map[string][]string { var output = map[string][]string{} for _, b := range a { for _, c := range b.Attributes { - output[c.FriendlyName] = []string{} + if _, ok := output[c.FriendlyName]; !ok { + output[c.FriendlyName] = []string{} + } for _, d := range c.Values { output[c.FriendlyName] = append(output[c.FriendlyName], d.Value) } diff --git a/modules/saml/saml.go b/modules/saml/saml.go index 2e74d75..a2926ab 100644 --- a/modules/saml/saml.go +++ b/modules/saml/saml.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rsa" "crypto/x509" + "errors" "fmt" "net/http" "net/url" @@ -12,20 +13,24 @@ import ( "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/gin-gonic/gin" - "github.com/pelletier/go-toml" "github.com/phuslu/log" "go.sebtobie.de/httpserver" "go.sebtobie.de/httpserver/auth" "gopkg.in/dgrijalva/jwt-go.v3" ) +func musturi(url *url.URL, err error) *url.URL { + if err != nil { + panic(err) + } + return url +} + var ( defaultsaml = &SAML{ - Selfsigned: false, - UnkownAuthority: false, - IDP: "https://samltest.id/saml/idp", - Domain: "example.com", - Cookiename: "ILOVECOOKIES", + idp: musturi(url.ParseRequestURI("https://samltest.id/saml/idp")), + SiteDomain: "example.com", + Cookiename: "ILOVECOOKIES", } _ httpserver.Site = defaultsaml ) @@ -34,76 +39,83 @@ type metadata struct{} // SAML is an Applicance to react on Events from the SAML-IDP and that provides an interface to get data from the IDP in a standartised fashion. type SAML struct { - router *gin.RouterGroup - config *toml.Tree - publicroot string - Keyfiles []string - SPPublicKey string - sppublickey *x509.Certificate - SPPrivatekey string - spprivatekey *rsa.PrivateKey - JWTPrivatekey string - jwtprivatekey *rsa.PrivateKey - Selfsigned bool - UnkownAuthority bool - IDP string `comment:"URL of the Metadata of the IDP"` - sp *saml.ServiceProvider - HTTPClient http.Client `toml:"-"` - Domain string - Cookiename string + router *gin.RouterGroup + publicroot string + Keyfiles []string + SPPublicKey string + sppublickey *x509.Certificate + SPPrivatekey string + spprivatekey *rsa.PrivateKey + JWTPrivatekey string + jwtprivatekey *rsa.PrivateKey + idp *url.URL + sp *saml.ServiceProvider + HTTPClient http.Client `toml:"-"` + SiteDomain string `toml:"domain"` + SitePath string `toml:"path"` + Cookiename string } -// NewSAMLEndpoint creates an endpoint which handles SAML Requests. -func NewSAMLEndpoint(config *toml.Tree) (s *SAML, err error) { - s = &(*defaultsaml) - s.config = config - log.Trace().Str("config", config.String()).Msg("config") +// Setup sets the saml object up. +func (s *SAML) Setup(config httpserver.SiteConfig) (err error) { var key interface{} - s.config = config - if err = config.Unmarshal(s); err != nil { - log.Error().Err(err).Msg("Error while mapping config to struct") - return + if keyfile, found := config["spprivatekey"]; found { + s.SPPrivatekey = keyfile.(string) + key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { + _, ok := key.(*rsa.PrivateKey) + return ok + }) + if err != nil { + return + } + s.spprivatekey = key.(*rsa.PrivateKey) + } else { + return errors.New("SP Privatekey not found") } - key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { - _, ok := key.(*rsa.PrivateKey) - return ok - }) - if err != nil { - return + if keyfile, found := config["sppublickey"]; found { + s.SPPublicKey = keyfile.(string) + key, err = initcert(s.SPPublicKey, func(key interface{}) bool { + _, ok := key.(*x509.Certificate) + return ok + }) + if err != nil { + return + } + s.sppublickey = key.(*x509.Certificate) + } else { + return errors.New("SP Publickey not found") } - s.spprivatekey = key.(*rsa.PrivateKey) - key, err = initcert(s.SPPublicKey, func(key interface{}) bool { - _, ok := key.(*x509.Certificate) - return ok - }) - if err != nil { - return + if keyfile, found := config["jwtprivatekey"]; found { + s.JWTPrivatekey = keyfile.(string) + key, err = initcert(s.JWTPrivatekey, func(key interface{}) bool { + _, ok := key.(*rsa.PrivateKey) + return ok + }) + if err != nil { + return + } + s.jwtprivatekey = key.(*rsa.PrivateKey) + } else { + return errors.New("JWT Privatekey not found") } - s.sppublickey = key.(*x509.Certificate) - - key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { - _, ok := key.(*rsa.PrivateKey) - return ok - }) - if err != nil { - return - } - s.jwtprivatekey = key.(*rsa.PrivateKey) s.sp = &saml.ServiceProvider{ - Key: s.spprivatekey, - Certificate: s.sppublickey, + Key: s.spprivatekey, + Certificate: s.sppublickey, + AuthnNameIDFormat: saml.PersistentNameIDFormat, } - var idpurl *url.URL - idpurl, err = url.ParseRequestURI(s.IDP) - if err != nil { - return + if idp, found := config["idp"]; found { + s.idp, err = url.ParseRequestURI(idp.(string)) + if err != nil { + return + } + s.sp.IDPMetadata, err = samlsp.FetchMetadata(context.Background(), &s.HTTPClient, *s.idp) + if err != nil { + return + } + } else { + err = errors.New("IDP in configfile not found") } - s.sp.IDPMetadata, err = samlsp.FetchMetadata(context.Background(), &s.HTTPClient, *idpurl) - if err != nil { - return - } - s.sp.AuthnNameIDFormat = saml.UnspecifiedNameIDFormat return } @@ -113,12 +125,12 @@ func (s *SAML) Init(router *gin.RouterGroup) { s.router = router s.sp.AcsURL = url.URL{ Scheme: "https", - Host: s.Domain, + Host: s.SiteDomain, Path: s.publicroot + "/acs", } s.sp.MetadataURL = url.URL{ Scheme: "https", - Host: s.Domain, + Host: s.SiteDomain, Path: s.publicroot + "/metadata.xml", } router.GET("/metadata.xml", s.metadataHF) @@ -130,7 +142,7 @@ func (s *SAML) Teardown() {} func (s *SAML) metadataHF(c *gin.Context) { m := s.sp.Metadata() - log.Debug().Time("Validuntil", m.ValidUntil).Msg("SP MEtadata") + log.Debug().Time("Validuntil", m.ValidUntil).Msg("SP Metadata") c.XML(http.StatusOK, m) } @@ -168,3 +180,19 @@ func (s *SAML) acsHF(c *gin.Context) { } c.Redirect(http.StatusSeeOther, redirect) } + +// Domain returns an the configured domain +func (s *SAML) Domain() string { + return s.SiteDomain +} + +// Defaults returns the default values for the config +func (s *SAML) Defaults() httpserver.SiteConfig { + return map[string]interface{}{ + "domain": "example.com", + "idp": defaultsaml.idp, + "sppublickey": "publickey.pem", + "spprivatekey": "privatekey.pem", + "jwtprivatekey": "privatekey.pem", + } +} diff --git a/modules/saml/saml_test.go b/modules/saml/saml_test.go new file mode 100644 index 0000000..9bce37a --- /dev/null +++ b/modules/saml/saml_test.go @@ -0,0 +1,22 @@ +package saml_test + +import ( + "testing" + + "go.sebtobie.de/httpserver" + "go.sebtobie.de/httpserver/auth" + "go.sebtobie.de/httpserver/modules/saml" +) + +func TestSamlMethods(t *testing.T) { + t.Parallel() + var samlo = &saml.SAML{} + var samlsite httpserver.Site = samlo + var _ auth.AuthenticationHandler = samlo + defaults := samlsite.Defaults() + if len(defaults) == 0 { + t.Log("There is an empty Default Object") + t.Fail() + } + samlsite.Setup(defaults) +} diff --git a/site.go b/site.go new file mode 100644 index 0000000..b8c6c93 --- /dev/null +++ b/site.go @@ -0,0 +1,25 @@ +package httpserver + +import "github.com/gin-gonic/gin" + +// Site is an Interface to abstract the modularized group of pages. +// The Middleware must be able to detect multiple calls by itself. Deduplication is not performed. +type Site interface { + Setup(SiteConfig) error + Init(*gin.RouterGroup) + Teardown() + Defaults() SiteConfig +} + +// SiteConfig is an interface for configitems of the site. The methods return the required items for the server +type SiteConfig map[string]interface{} + +// Domain gives an easier access to the domain value +func (sc SiteConfig) Domain() string { + return sc["domain"].(string) +} + +// Path gives an easier access to the path value +func (sc SiteConfig) Path() string { + return sc["path"].(string) +}