package saml import ( "context" "crypto/rsa" "crypto/x509" "fmt" "net/http" "net/url" "time" "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/gin-gonic/gin" "github.com/pelletier/go-toml" "github.com/phuslu/log" "go.sebtobie.de/httpserver" "gopkg.in/dgrijalva/jwt-go.v3" ) const ( HJWT = "jwt" HSPPrivate = "sppriv" HSPPublic = "sppub" ) var ( defaultsaml = &SAML{ Selfsigned: false, UnkownAuthority: false, IDP: "https://samltest.id/saml/idp", Domain: "example.com", Cookiename: "ILOVECOOKIES", } ) type metadata struct{} // SAML is an Applicance to react on Events from the SAML-IDP and that provides an interface to get data from the IDP in a standartised fashion. type SAML struct { router *gin.RouterGroup config *toml.Tree publicroot string Keyfiles []string SPPublicKey string sppublickey *x509.Certificate SPPrivatekey string spprivatekey *rsa.PrivateKey JWTPrivatekey string jwtprivatekey *rsa.PrivateKey Selfsigned bool UnkownAuthority bool IDP string `comment:"URL of the Metadata of the IDP"` sp *saml.ServiceProvider HttpClient http.Client `toml:"-"` Domain string Cookiename string } // NewSAMLEndpoint creates an endpoint which handles SAML Requests. func NewSAMLEndpoint(config *toml.Tree) (*SAML, error) { var key interface{} var err error var s SAML = *defaultsaml s.config = config if err := config.Unmarshal(&s); err != nil { log.Error().Err(err).Msg("Error while mapping config to struct") return nil, err } log.Trace().Interface("config", config).Msg("cofnig") key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { _, ok := key.(*rsa.PrivateKey) return ok }) if err != nil { return nil, err } s.spprivatekey = key.(*rsa.PrivateKey) key, err = initcert(s.SPPublicKey, func(key interface{}) bool { _, ok := key.(*x509.Certificate) return ok }) if err != nil { return nil, err } s.sppublickey = key.(*x509.Certificate) key, err = initcert(s.SPPrivatekey, func(key interface{}) bool { _, ok := key.(*rsa.PrivateKey) return ok }) if err != nil { return nil, err } s.jwtprivatekey = key.(*rsa.PrivateKey) s.sp = &saml.ServiceProvider{ Key: s.spprivatekey, Certificate: s.sppublickey, } var idpurl *url.URL idpurl, err = url.ParseRequestURI(s.IDP) if err != nil { return nil, err } s.sp.IDPMetadata, err = samlsp.FetchMetadata(context.Background(), &s.HttpClient, *idpurl) if err != nil { return nil, err } s.sp.AuthnNameIDFormat = saml.UnspecifiedNameIDFormat return &s, nil } // Init initalizes the routes func (s *SAML) Init(router *gin.RouterGroup) { s.publicroot = router.BasePath() s.router = router s.sp.AcsURL = url.URL{ Scheme: "https", Host: s.Domain, Path: s.publicroot + "/acs", } s.sp.MetadataURL = url.URL{ Scheme: "https", Host: s.Domain, Path: s.publicroot + "/metadata.xml", } router.GET("/metadata.xml", s.metadataHF) router.POST("/acs", s.acsHF) } // Middleware returns the Required Middleware func (s *SAML) Middleware() []gin.HandlerFunc { return []gin.HandlerFunc{} } func (s *SAML) metadataHF(c *gin.Context) { m := s.sp.Metadata() log.Debug().Time("Validuntil", m.ValidUntil).Msg("SP MEtadata") c.XML(http.StatusOK, m) } func (s *SAML) acsHF(c *gin.Context) { account := c.MustGet("account").(httpserver.Account) err := c.Request.ParseForm() if err != nil { c.AbortWithError(http.StatusNotAcceptable, err) } var assert *saml.Assertion assert, err = s.sp.ParseResponse(c.Request, []string{account.Get("jti").(string)}) if err != nil { realerr, _ := err.(*saml.InvalidResponseError) err = realerr.PrivateErr log.Error().AnErr("Assertionerror", err).Msgf("Assertion Error") fmt.Print(realerr.Response) c.AbortWithStatus(http.StatusBadRequest) return } data := attributeStatementstomap(assert.AttributeStatements) token, err := jwttoken(jwt.MapClaims{ httpserver.AccountAnon: false, httpserver.AccountID: account.Get(httpserver.AccountID).(string), httpserver.AccountUser: data["uid"][0], }, s.jwtprivatekey) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.SetCookie(s.Cookiename, token, int(time.Hour*24*30), "", "", true, true) redirect, found := c.GetPostForm("RelayState") if !found { c.AbortWithStatus(http.StatusNotAcceptable) return } c.Redirect(http.StatusSeeOther, redirect) }