From 6d43861ffc94c715f3991db34a98dba81212d496 Mon Sep 17 00:00:00 2001 From: Sebastian Tobie Date: Sun, 7 Nov 2021 18:51:39 +0100 Subject: [PATCH] some bigger changes - The detection of unix and TCP sockets is much better; - The configuration is now easier exachangable; - middleware doesn't get lost; - the defaults func is now supplying the config it needs into the defaults; --- funcs/funcs.go | 22 ++++++++++++ funcs/funcs_test.go | 47 +++++++++++++++++++++++++ go.mod | 2 +- go.sum | 2 ++ http.go | 82 +++++++++++++++++++++++++++++--------------- modules/saml/saml.go | 66 ++++++++++++++++++++++------------- site.go | 10 ------ 7 files changed, 170 insertions(+), 61 deletions(-) create mode 100644 funcs/funcs_test.go diff --git a/funcs/funcs.go b/funcs/funcs.go index fd96b98..d386928 100644 --- a/funcs/funcs.go +++ b/funcs/funcs.go @@ -2,6 +2,7 @@ package funcs import ( "context" + "net" "github.com/jackc/pgx/v4" "github.com/phuslu/log" @@ -34,3 +35,24 @@ func (PGXLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data m } var _ pgx.Logger = PGXLogger{} + +// IsUnix tests if the address is an unix address. It returns false if its an tcp address. +func IsUnix(address string) bool { + if IsTCP(address) { + return false + } + + if _, err := net.ResolveUnixAddr("unix", address); err == nil { + return true + } + return false +} + +// IsTCP tests if the address is an tcp address +func IsTCP(address string) bool { + _, err := net.ResolveTCPAddr("tcp", address) + if err == nil { + return true + } + return false +} diff --git a/funcs/funcs_test.go b/funcs/funcs_test.go new file mode 100644 index 0000000..ad0b1fb --- /dev/null +++ b/funcs/funcs_test.go @@ -0,0 +1,47 @@ +package funcs_test + +import ( + "testing" + + "go.sebtobie.de/httpserver/funcs" +) + +type testtype struct { + name string + arg string + want bool +} + +func TestIsUnix(t *testing.T) { + tests := []testtype{ + {name: "absolute unix", arg: "/tmp/xyz", want: true}, + {name: "relative unix", arg: "xyz", want: true}, + {name: "simple tcp4 address", arg: "127.0.0.1:0", want: false}, + {name: "simple tcp6 address", arg: "[::1]:0", want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := funcs.IsUnix(tt.arg); got != tt.want { + t.Errorf("IsUnix() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsTCP(t *testing.T) { + tests := []testtype{ + {name: "simple tcp4 address", arg: "127.0.0.1:0", want: true}, + {name: "simple tcp6 address", arg: "[::1]:0", want: true}, + {name: "too big ip address", arg: "999.0.0.1:0", want: false}, + {name: "too big port", arg: "127.0.0.1:65536", want: false}, + {name: "unix relative", arg: "xyz.sock", want: false}, + {name: "unix absolute", arg: "/xyz.sock", want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := funcs.IsTCP(tt.arg); got != tt.want { + t.Errorf("IsTCP() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index 1cc6a5d..3cdf47d 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/shopspring/decimal v1.2.0 // indirect github.com/ugorji/go v1.2.6 // indirect golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect - golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70 // indirect + golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68 // indirect golang.org/x/text v0.3.7 // indirect google.golang.org/protobuf v1.27.1 // indirect gopkg.in/dgrijalva/jwt-go.v3 v3.2.0 diff --git a/go.sum b/go.sum index 1cd88fe..632755e 100644 --- a/go.sum +++ b/go.sum @@ -351,6 +351,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70 h1:SeSEfdIxyvwGJliREIJhRPPXvW6sDlLT+UQ3B0hD0NA= golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68 h1:Ywe/f3fNleF8I6F6qv3MeFoSZ6CTf2zBMMa/7qVML8M= +golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/http.go b/http.go index fca40bb..682003c 100644 --- a/http.go +++ b/http.go @@ -11,26 +11,29 @@ import ( "github.com/gin-gonic/gin" "github.com/phuslu/log" "go.sebtobie.de/httpserver/auth" + "go.sebtobie.de/httpserver/funcs" "go.sebtobie.de/httpserver/menus" "go.sebtobie.de/httpserver/templates" ) func init() { gin.DebugPrintRouteFunc = func(httpMethod, absolutePath, handlerName string, nuHandlers int) { - log.Trace().Msgf("%-4s(%02d): %-20s %s", httpMethod, nuHandlers, absolutePath, handlerName) + log.Debug().Msgf("%-4s(%02d): %-20s %s", httpMethod, nuHandlers, absolutePath, handlerName) } + gin.SetMode(gin.DebugMode) } // Config that is used to map the toml config to the settings that are used. type Config struct { Addr []string TLSAddr []string - TLSconfig *tls.Config + TLSconfig *tls.Config `toml:"-"` Certfile string Keyfile string Sites map[string]SiteConfig } +/** // MarshalObject adds the information over the object to the *log.Entry func (c *Config) MarshalObject(e *log.Entry) { e.Strs("Address", c.Addr).Bool("TLS", c.TLSconfig != nil) @@ -38,9 +41,11 @@ func (c *Config) MarshalObject(e *log.Entry) { e.Str("Certfile", c.Certfile) e.Str("Keyfile", c.Keyfile) } + e.Int("sites", len(c.Sites)) } var _ log.ObjectMarshaler = &Config{} +/**/ // Server is an wrapper for the *http.Server and *gin.Engine type Server struct { @@ -54,6 +59,7 @@ type Server struct { routines sync.WaitGroup setup bool authh auth.AuthenticationHandler + middleware gin.HandlersChain } // CreateServer creates an server that can be run in a coroutine. @@ -63,11 +69,16 @@ func CreateServer() *Server { gin.DefaultWriter = log.DefaultLogger.Std(log.DebugLevel, log.Context{}, "GIN", 0).Writer() log.Info().Msg("Creating HTTP-Server") var server = &Server{ - Conf: &Config{}, + Conf: &Config{ + TLSconfig: &tls.Config{}, + Sites: map[string]SiteConfig{}, + }, mrouter: map[string]*gin.Engine{}, authh: &auth.AnonAccountHandler{}, menu: []menus.Menu{}, NotFoundHandler: http.NotFoundHandler(), + sites: map[string]Site{}, + middleware: gin.HandlersChain{}, } server.http = &http.Server{ ErrorLog: log.DefaultLogger.Std(log.ErrorLevel, log.Context{}, "", 0), @@ -78,25 +89,36 @@ func CreateServer() *Server { // runPort runs a listener on the port. his enables th server to serve more than a address. func (s *Server) runPort(address string, tls bool) { + defer s.routines.Done() var socket net.Listener var err error - if net.ParseIP(address) != nil { + var unix string + + if funcs.IsTCP(address) { socket, err = net.Listen("tcp", address) - } else { + } + if funcs.IsUnix(address) { + unix = "Unix-" socket, err = net.Listen("unix", address) } if err != nil { log.Error().Err(err).Msgf("failed to open socket on %s", address) + return + } + if socket == nil { + log.Error().Msg("Failed to identify the sockettype") + return } if tls { + log.Info().Msgf("starting listen on secure %ssocket %s", unix, address) err = s.http.ServeTLS(socket, s.Conf.Certfile, s.Conf.Keyfile) } else { + log.Info().Msgf("starting listen on %ssocket %s", unix, address) err = s.http.Serve(socket) } if err != http.ErrServerClosed { log.Error().Err(err).Msg("Socket unexpected exited") } - s.routines.Done() } // SetAuthentication sets the handler that is responsible for authentication @@ -113,7 +135,6 @@ func (s *Server) StartServer() { } log.Info().Msg("Starting server") s.http.TLSConfig = s.Conf.TLSconfig - var err error if s.Conf.Certfile != "" && s.Conf.Keyfile != "" { for _, addr := range s.Conf.TLSAddr { s.routines.Add(1) @@ -121,11 +142,9 @@ func (s *Server) StartServer() { } } for _, addr := range s.Conf.Addr { + s.routines.Add(1) go s.runPort(addr, false) } - if err != http.ErrServerClosed { - log.Error().Err(err).Msg("Server unexpected exited") - } s.routines.Wait() } @@ -145,24 +164,27 @@ func (s *Server) DomainRouter(w http.ResponseWriter, r *http.Request) { log.Trace().Strs(header, value).Msg("Headers") } if router, found := s.mrouter[domain]; found { + log.Info().Bool("isnil", router == nil).Str("domain", domain).Msg("Debuginfo") router.NoMethod(gin.WrapH(s.NotFoundHandler)) router.NoRoute(gin.WrapH(s.NotFoundHandler)) router.ServeHTTP(w, r) return } + log.Error().Msgf("Failed to find domain for %s", domain) var entrys []string for d := range s.mrouter { entrys = append(entrys, d) } - log.Trace().Strs("reqistred domains", entrys).Msg("domain not found") + log.Trace().Strs("registred domains", entrys).Msg("domain not found") s.NotFoundHandler.ServeHTTP(w, r) } // Use installs the middleware into the router. // The Middleware must be able to detect multiple calls byy itself. Deduplication is not performed. func (s *Server) Use(m ...gin.HandlerFunc) { - for _, router := range s.mrouter { - router.Use(m...) + s.middleware = append(s.middleware, m...) + for _, site := range s.mrouter { + site.Use(m...) } } @@ -178,37 +200,37 @@ func (s *Server) menus() []menus.Menu { return s.menu } -// Setup ets the server up. It loads the sites and prepare the server for startup. +// Setup sets the server up. It loads the sites and prepare the server for startup. // The sites get their config in this step. func (s *Server) Setup() { var router *gin.Engine var found bool for cfg, site := range s.sites { config := s.Conf.Sites[cfg] - if router, found = s.mrouter[config.Domain()]; !found { + if router, found = s.mrouter[config["domain"].(string)]; !found { router = gin.New() - mw := []gin.HandlerFunc{ - func(c *gin.Context) { - c.Set(Menus, s.menus) - c.Set(Domain, config.Domain()) - }, - } - mw = append(mw, func(c *gin.Context) { s.authh.Account(c) }) - router.Use(mw...) - s.mrouter[config.Domain()] = router + router.Use(func(c *gin.Context) { + c.Set(Domain, config["domain"]) + c.Set(Menus, s.menus) + c.Set(Accounts, s.authh.Account(c)) + }) + router.Use(s.middleware...) + s.mrouter[config["domain"].(string)] = router } - site.Init(router.Group(config.Path())) + group := router.Group(config["path"].(string)) + site.Init(group) if ms, ok := site.(menus.MenuSite); ok { - menus := ms.Menu(config.Domain()) + menus := ms.Menu(config["domain"].(string)) log.Debug().Msgf("%d menus are added", len(menus)) s.menu = append(s.menu, menus...) } if ts, ok := site.(templates.TemplateSite); ok { templates := ts.Templates() - log.Debug().Msgf("Templates for %s%s are added", config.Domain(), config.Path()) + log.Debug().Msgf("Templates for %s%s are added", config["domain"], config["path"].(string)) s.template.AddParseTree(templates.Name(), templates.Tree) s.template.Funcs(ts.Funcs()) } + site.Setup(config) } s.setup = true } @@ -217,6 +239,12 @@ func (s *Server) Setup() { // it registers the defaults so that the application can load/dump it from/into an configfile or commandline options func (s *Server) RegisterSite(cfg string, site Site) { var config = site.Defaults() + if _, found := config["domain"]; !found { + config["domain"] = "" + } + if _, found := config["path"]; !found { + config["path"] = "" + } s.Conf.Sites[cfg] = config s.sites[cfg] = site } diff --git a/modules/saml/saml.go b/modules/saml/saml.go index a2926ab..9318247 100644 --- a/modules/saml/saml.go +++ b/modules/saml/saml.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "net/url" + "path" "time" "github.com/crewjam/saml" @@ -29,7 +30,7 @@ func musturi(url *url.URL, err error) *url.URL { var ( defaultsaml = &SAML{ idp: musturi(url.ParseRequestURI("https://samltest.id/saml/idp")), - SiteDomain: "example.com", + Domain: "example.com", Cookiename: "ILOVECOOKIES", } _ httpserver.Site = defaultsaml @@ -41,7 +42,6 @@ type metadata struct{} type SAML struct { router *gin.RouterGroup publicroot string - Keyfiles []string SPPublicKey string sppublickey *x509.Certificate SPPrivatekey string @@ -51,13 +51,40 @@ type SAML struct { idp *url.URL sp *saml.ServiceProvider HTTPClient http.Client `toml:"-"` - SiteDomain string `toml:"domain"` - SitePath string `toml:"path"` + Domain string `toml:"domain"` Cookiename string } // Setup sets the saml object up. func (s *SAML) Setup(config httpserver.SiteConfig) (err error) { + log.Info().Msg("Setting up SAML service provider") + s.Domain = config["domain"].(string) + s.sp = &saml.ServiceProvider{ + AcsURL: url.URL{ + Scheme: "https", + Host: s.Domain, + Path: path.Join(s.publicroot, "acs"), + }, + MetadataURL: url.URL{ + Scheme: "https", + Host: s.Domain, + Path: path.Join(s.publicroot, "metadata.xml"), + }, + } + switch config["metadatavalid"].(type) { + case time.Duration: + s.sp.MetadataValidDuration = config["metadatavalid"].(time.Duration) + case int: + s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int)) + case int8: + s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int8)) + case int16: + s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int16)) + case int32: + s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int32)) + case int64: + s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int64)) + } var key interface{} if keyfile, found := config["spprivatekey"]; found { s.SPPrivatekey = keyfile.(string) @@ -123,26 +150,23 @@ func (s *SAML) Setup(config httpserver.SiteConfig) (err error) { func (s *SAML) Init(router *gin.RouterGroup) { s.publicroot = router.BasePath() s.router = router - s.sp.AcsURL = url.URL{ - Scheme: "https", - Host: s.SiteDomain, - Path: s.publicroot + "/acs", - } - s.sp.MetadataURL = url.URL{ - Scheme: "https", - Host: s.SiteDomain, - Path: s.publicroot + "/metadata.xml", - } - router.GET("/metadata.xml", s.metadataHF) - router.POST("/acs", s.acsHF) + router.GET("metadata.xml", s.metadataHF) + router.POST("acs", s.acsHF) } // Teardown is to satisfy the httpserver.Site interface. func (s *SAML) Teardown() {} func (s *SAML) metadataHF(c *gin.Context) { + if s.sp == nil { + c.AbortWithStatus(500) + return + } m := s.sp.Metadata() - log.Debug().Time("Validuntil", m.ValidUntil).Msg("SP Metadata") + if m == nil { + c.AbortWithStatus(500) + return + } c.XML(http.StatusOK, m) } @@ -181,18 +205,14 @@ func (s *SAML) acsHF(c *gin.Context) { c.Redirect(http.StatusSeeOther, redirect) } -// Domain returns an the configured domain -func (s *SAML) Domain() string { - return s.SiteDomain -} - // Defaults returns the default values for the config func (s *SAML) Defaults() httpserver.SiteConfig { return map[string]interface{}{ "domain": "example.com", - "idp": defaultsaml.idp, + "idp": defaultsaml.idp.String(), "sppublickey": "publickey.pem", "spprivatekey": "privatekey.pem", "jwtprivatekey": "privatekey.pem", + "metadatavalid": time.Duration(time.Hour * 24), } } diff --git a/site.go b/site.go index b8c6c93..551ea1e 100644 --- a/site.go +++ b/site.go @@ -13,13 +13,3 @@ type Site interface { // SiteConfig is an interface for configitems of the site. The methods return the required items for the server type SiteConfig map[string]interface{} - -// Domain gives an easier access to the domain value -func (sc SiteConfig) Domain() string { - return sc["domain"].(string) -} - -// Path gives an easier access to the path value -func (sc SiteConfig) Path() string { - return sc["path"].(string) -}