package saml import ( "context" "crypto/rsa" "crypto/x509" "errors" "fmt" "net/http" "net/url" "path" "time" "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v4" "github.com/rs/zerolog/log" "go.sebtobie.de/httpserver" "go.sebtobie.de/httpserver/auth" "go.sebtobie.de/httpserver/constants" ) func musturi(url *url.URL, err error) *url.URL { if err != nil { panic(err) } return url } var ( defaultsaml = &SAML{ idp: musturi(url.ParseRequestURI("https://samltest.id/saml/idp")), Domain: "example.com", Cookiename: "ILOVECOOKIES", } _ httpserver.Site = defaultsaml _ httpserver.ConfigSite = defaultsaml ) type metadata struct{} // SAML is an Applicance to react on Events from the SAML-IDP and that provides an interface to get data from the IDP in a standartised fashion. type SAML struct { router *gin.RouterGroup publicroot string SPPublicKey string sppublickey *x509.Certificate SPPrivatekey string spprivatekey *rsa.PrivateKey JWTPrivatekey string jwtprivatekey *rsa.PrivateKey idp *url.URL sp *saml.ServiceProvider HTTPClient http.Client `toml:"-"` Domain string `toml:"domain"` Cookiename string } // Setup sets the saml object up. func (s *SAML) Setup(config httpserver.SiteConfig) (err error) { log.Info().Msg("Setting up SAML service provider") s.Domain = config["domain"].(string) s.sp = &saml.ServiceProvider{ AcsURL: url.URL{ Scheme: "https", Host: s.Domain, Path: path.Join(s.publicroot, "acs"), }, MetadataURL: url.URL{ Scheme: "https", Host: s.Domain, Path: path.Join(s.publicroot, "metadata.xml"), }, } switch config["metadatavalid"].(type) { case time.Duration: s.sp.MetadataValidDuration = config["metadatavalid"].(time.Duration) case int: s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int)) case int8: s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int8)) case int16: s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int16)) case int32: s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int32)) case int64: s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int64)) } var key interface{} if keyfile, found := config["spprivatekey"]; found { s.SPPrivatekey = keyfile.(string) key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { _, ok := key.(*rsa.PrivateKey) return ok }) if err != nil { return } s.spprivatekey = key.(*rsa.PrivateKey) } else { return errors.New("SP Privatekey not found") } if keyfile, found := config["sppublickey"]; found { s.SPPublicKey = keyfile.(string) key, err = initcert(s.SPPublicKey, func(key interface{}) bool { _, ok := key.(*x509.Certificate) return ok }) if err != nil { return } s.sppublickey = key.(*x509.Certificate) } else { return errors.New("SP Publickey not found") } if keyfile, found := config["jwtprivatekey"]; found { s.JWTPrivatekey = keyfile.(string) key, err = initcert(s.JWTPrivatekey, func(key interface{}) bool { _, ok := key.(*rsa.PrivateKey) return ok }) if err != nil { return } s.jwtprivatekey = key.(*rsa.PrivateKey) } else { return errors.New("JWT Privatekey not found") } s.sp = &saml.ServiceProvider{ Key: s.spprivatekey, Certificate: s.sppublickey, AuthnNameIDFormat: saml.PersistentNameIDFormat, } if idp, found := config["idp"]; found { s.idp, err = url.ParseRequestURI(idp.(string)) if err != nil { return } s.sp.IDPMetadata, err = samlsp.FetchMetadata(context.Background(), &s.HTTPClient, *s.idp) if err != nil { return } } else { err = errors.New("IDP in configfile not found") } return } // Init initalizes the routes func (s *SAML) Init(router *gin.RouterGroup) { s.publicroot = router.BasePath() s.router = router router.GET("metadata.xml", s.metadataHF) router.POST("acs", s.acsHF) } func (s *SAML) metadataHF(c *gin.Context) { if s.sp == nil { c.AbortWithStatus(500) return } m := s.sp.Metadata() if m == nil { c.AbortWithStatus(500) return } c.XML(http.StatusOK, m) } func (s *SAML) acsHF(c *gin.Context) { account := c.MustGet("account").(auth.Account) err := c.Request.ParseForm() if err != nil { c.AbortWithError(http.StatusNotAcceptable, err) } var assert *saml.Assertion assert, err = s.sp.ParseResponse(c.Request, []string{account.Get("jti").(string)}) if err != nil { realerr, _ := err.(*saml.InvalidResponseError) err = realerr.PrivateErr log.Error().AnErr("Assertionerror", err).Msgf("Assertion Error") fmt.Print(realerr.Response) c.AbortWithStatus(http.StatusBadRequest) return } data := attributeStatementstomap(assert.AttributeStatements) token, err := jwttoken(jwt.MapClaims{ string(constants.AccountAnon): false, string(constants.AccountID): account.Get(constants.AccountID).(string), string(constants.AccountUser): data["uid"][0], }, s.jwtprivatekey) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.SetCookie(s.Cookiename, token, int(time.Hour*24*30), "", "", true, true) redirect, found := c.GetPostForm("RelayState") if !found { c.AbortWithStatus(http.StatusNotAcceptable) return } c.Redirect(http.StatusSeeOther, redirect) } // Defaults returns the default values for the config func (s *SAML) Defaults() httpserver.SiteConfig { return map[string]interface{}{ "domain": "example.com", "idp": defaultsaml.idp.String(), "sppublickey": "publickey.pem", "spprivatekey": "privatekey.pem", "jwtprivatekey": "privatekey.pem", "metadatavalid": time.Duration(time.Hour * 24), } }