mirror of
				https://codeberg.org/forgejo/forgejo.git
				synced 2025-10-24 19:12:24 +00:00 
			
		
		
		
	Manually backported from https://github.com/go-gitea/gitea/pull/27995. The only conflict was `ctx` versus `db.DefaultContext`.
		
			
				
	
	
		
			414 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			414 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2014 The Gogs Authors. All rights reserved.
 | |
| // Copyright 2019 The Gitea Authors. All rights reserved.
 | |
| // SPDX-License-Identifier: MIT
 | |
| 
 | |
| package auth
 | |
| 
 | |
import (
	"errors"
	"fmt"
	"reflect"

	"code.gitea.io/gitea/models/db"
	"code.gitea.io/gitea/modules/log"
	"code.gitea.io/gitea/modules/timeutil"
	"code.gitea.io/gitea/modules/util"

	"xorm.io/xorm"
	"xorm.io/xorm/convert"
)
 | |
| 
 | |
// Type identifies the kind of a login (authentication) source.
type Type int

// Login source types. The numeric values are stored in the Source.Type
// column, so a new type must only ever be appended to the end of this list
// to stay compatible with existing rows.
const (
	NoType Type = iota // 0
	Plain              // 1
	LDAP               // 2
	SMTP               // 3
	PAM                // 4
	DLDAP              // 5
	OAuth2             // 6
	SSPI               // 7
)

// String returns the display name of the type as listed in Names, or the
// empty string for a type without an entry (NoType and Plain have none).
func (typ Type) String() string {
	name := Names[typ]
	return name
}

// Int returns the underlying integer value of the type.
func (typ Type) Int() int {
	value := int(typ)
	return value
}

// Names maps source types to their human-readable display names.
var Names = map[Type]string{
	LDAP:   "LDAP (via BindDN)",
	DLDAP:  "LDAP (simple auth)", // authenticates via a direct bind
	SMTP:   "SMTP",
	PAM:    "PAM",
	OAuth2: "OAuth2",
	SSPI:   "SPNEGO with SSPI",
}
 | |
| 
 | |
| // Config represents login config as far as the db is concerned
 | |
| type Config interface {
 | |
| 	convert.Conversion
 | |
| }
 | |
| 
 | |
// SkipVerifiable configurations provide an IsSkipVerify method to check
// whether certificate verification should be skipped.
type SkipVerifiable interface {
	IsSkipVerify() bool
}

// HasTLSer configurations provide a HasTLS method to check whether TLS can be
// enabled.
type HasTLSer interface {
	HasTLS() bool
}

// UseTLSer configurations provide a UseTLS method to check whether TLS is
// enabled.
type UseTLSer interface {
	UseTLS() bool
}

// SSHKeyProvider configurations provide a ProvidesSSHKeys method to check
// whether they supply SSH keys.
type SSHKeyProvider interface {
	ProvidesSSHKeys() bool
}

// RegisterableSource configurations need extra setup when they are created
// or updated: CreateSource and UpdateSource call RegisterSource for active
// sources whose configuration implements this interface.
type RegisterableSource interface {
	RegisterSource() error
	UnregisterSource() error
}
 | |
| 
 | |
| var registeredConfigs = map[Type]func() Config{}
 | |
| 
 | |
| // RegisterTypeConfig register a config for a provided type
 | |
| func RegisterTypeConfig(typ Type, exemplar Config) {
 | |
| 	if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
 | |
| 		// Pointer:
 | |
| 		registeredConfigs[typ] = func() Config {
 | |
| 			return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
 | |
| 		}
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// Not a Pointer
 | |
| 	registeredConfigs[typ] = func() Config {
 | |
| 		return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
 | |
| 	}
 | |
| }
 | |
| 
 | |
// SourceSettable configurations can have their authSource set on them.
// In this file, SetAuthSource is called from BeforeSet, CreateSource and
// UpdateSource, passing the Source that holds the configuration in Cfg.
type SourceSettable interface {
	SetAuthSource(*Source)
}

// Source represents an external way for authorizing users.
//
// Fields:
//   - Type selects the configuration shape; BeforeSet uses it to pick the
//     registered constructor for Cfg.
//   - Name must be unique (UNIQUE column; CreateSource and UpdateSource also
//     check it and return ErrSourceAlreadyExist).
//   - IsSyncEnabled is forced to false by CreateSource for non-LDAP sources.
//   - Cfg holds the type-specific configuration, stored in a TEXT column.
//   - CreatedUnix / UpdatedUnix carry xorm's "created" / "updated" tags.
//
// NOTE(review): field names and xorm tags define the database schema; keep
// them unchanged unless a migration accompanies the change.
type Source struct {
	ID            int64 `xorm:"pk autoincr"`
	Type          Type
	Name          string             `xorm:"UNIQUE"`
	IsActive      bool               `xorm:"INDEX NOT NULL DEFAULT false"`
	IsSyncEnabled bool               `xorm:"INDEX NOT NULL DEFAULT false"`
	Cfg           convert.Conversion `xorm:"TEXT"`

	CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
	UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}

// TableName xorm will read the table name from this method.
// Source rows live in the "login_source" table.
func (Source) TableName() string {
	return "login_source"
}

// init registers the Source model with db.RegisterModel.
func init() {
	db.RegisterModel(new(Source))
}
 | |
| 
 | |
| // BeforeSet is invoked from XORM before setting the value of a field of this object.
 | |
| func (source *Source) BeforeSet(colName string, val xorm.Cell) {
 | |
| 	if colName == "type" {
 | |
| 		typ := Type(db.Cell2Int64(val))
 | |
| 		constructor, ok := registeredConfigs[typ]
 | |
| 		if !ok {
 | |
| 			return
 | |
| 		}
 | |
| 		source.Cfg = constructor()
 | |
| 		if settable, ok := source.Cfg.(SourceSettable); ok {
 | |
| 			settable.SetAuthSource(source)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // TypeName return name of this login source type.
 | |
| func (source *Source) TypeName() string {
 | |
| 	return Names[source.Type]
 | |
| }
 | |
| 
 | |
| // IsLDAP returns true of this source is of the LDAP type.
 | |
| func (source *Source) IsLDAP() bool {
 | |
| 	return source.Type == LDAP
 | |
| }
 | |
| 
 | |
| // IsDLDAP returns true of this source is of the DLDAP type.
 | |
| func (source *Source) IsDLDAP() bool {
 | |
| 	return source.Type == DLDAP
 | |
| }
 | |
| 
 | |
| // IsSMTP returns true of this source is of the SMTP type.
 | |
| func (source *Source) IsSMTP() bool {
 | |
| 	return source.Type == SMTP
 | |
| }
 | |
| 
 | |
| // IsPAM returns true of this source is of the PAM type.
 | |
| func (source *Source) IsPAM() bool {
 | |
| 	return source.Type == PAM
 | |
| }
 | |
| 
 | |
| // IsOAuth2 returns true of this source is of the OAuth2 type.
 | |
| func (source *Source) IsOAuth2() bool {
 | |
| 	return source.Type == OAuth2
 | |
| }
 | |
| 
 | |
| // IsSSPI returns true of this source is of the SSPI type.
 | |
| func (source *Source) IsSSPI() bool {
 | |
| 	return source.Type == SSPI
 | |
| }
 | |
| 
 | |
| // HasTLS returns true of this source supports TLS.
 | |
| func (source *Source) HasTLS() bool {
 | |
| 	hasTLSer, ok := source.Cfg.(HasTLSer)
 | |
| 	return ok && hasTLSer.HasTLS()
 | |
| }
 | |
| 
 | |
| // UseTLS returns true of this source is configured to use TLS.
 | |
| func (source *Source) UseTLS() bool {
 | |
| 	useTLSer, ok := source.Cfg.(UseTLSer)
 | |
| 	return ok && useTLSer.UseTLS()
 | |
| }
 | |
| 
 | |
| // SkipVerify returns true if this source is configured to skip SSL
 | |
| // verification.
 | |
| func (source *Source) SkipVerify() bool {
 | |
| 	skipVerifiable, ok := source.Cfg.(SkipVerifiable)
 | |
| 	return ok && skipVerifiable.IsSkipVerify()
 | |
| }
 | |
| 
 | |
| // CreateSource inserts a AuthSource in the DB if not already
 | |
| // existing with the given name.
 | |
| func CreateSource(source *Source) error {
 | |
| 	has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	} else if has {
 | |
| 		return ErrSourceAlreadyExist{source.Name}
 | |
| 	}
 | |
| 	// Synchronization is only available with LDAP for now
 | |
| 	if !source.IsLDAP() {
 | |
| 		source.IsSyncEnabled = false
 | |
| 	}
 | |
| 
 | |
| 	_, err = db.GetEngine(db.DefaultContext).Insert(source)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if !source.IsActive {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	if settable, ok := source.Cfg.(SourceSettable); ok {
 | |
| 		settable.SetAuthSource(source)
 | |
| 	}
 | |
| 
 | |
| 	registerableSource, ok := source.Cfg.(RegisterableSource)
 | |
| 	if !ok {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	err = registerableSource.RegisterSource()
 | |
| 	if err != nil {
 | |
| 		// remove the AuthSource in case of errors while registering configuration
 | |
| 		if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).Delete(new(Source)); err != nil {
 | |
| 			log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // Sources returns a slice of all login sources found in DB.
 | |
| func Sources() ([]*Source, error) {
 | |
| 	auths := make([]*Source, 0, 6)
 | |
| 	return auths, db.GetEngine(db.DefaultContext).Find(&auths)
 | |
| }
 | |
| 
 | |
| // SourcesByType returns all sources of the specified type
 | |
| func SourcesByType(loginType Type) ([]*Source, error) {
 | |
| 	sources := make([]*Source, 0, 1)
 | |
| 	if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return sources, nil
 | |
| }
 | |
| 
 | |
| // AllActiveSources returns all active sources
 | |
| func AllActiveSources() ([]*Source, error) {
 | |
| 	sources := make([]*Source, 0, 5)
 | |
| 	if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return sources, nil
 | |
| }
 | |
| 
 | |
| // ActiveSources returns all active sources of the specified type
 | |
| func ActiveSources(tp Type) ([]*Source, error) {
 | |
| 	sources := make([]*Source, 0, 1)
 | |
| 	if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return sources, nil
 | |
| }
 | |
| 
 | |
| // IsSSPIEnabled returns true if there is at least one activated login
 | |
| // source of type LoginSSPI
 | |
| func IsSSPIEnabled() bool {
 | |
| 	if !db.HasEngine {
 | |
| 		return false
 | |
| 	}
 | |
| 	sources, err := ActiveSources(SSPI)
 | |
| 	if err != nil {
 | |
| 		log.Error("ActiveSources: %v", err)
 | |
| 		return false
 | |
| 	}
 | |
| 	return len(sources) > 0
 | |
| }
 | |
| 
 | |
| // GetSourceByID returns login source by given ID.
 | |
| func GetSourceByID(id int64) (*Source, error) {
 | |
| 	source := new(Source)
 | |
| 	if id == 0 {
 | |
| 		source.Cfg = registeredConfigs[NoType]()
 | |
| 		// Set this source to active
 | |
| 		// FIXME: allow disabling of db based password authentication in future
 | |
| 		source.IsActive = true
 | |
| 		return source, nil
 | |
| 	}
 | |
| 
 | |
| 	has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	} else if !has {
 | |
| 		return nil, ErrSourceNotExist{id}
 | |
| 	}
 | |
| 	return source, nil
 | |
| }
 | |
| 
 | |
| // UpdateSource updates a Source record in DB.
 | |
| func UpdateSource(source *Source) error {
 | |
| 	var originalSource *Source
 | |
| 	if source.IsOAuth2() {
 | |
| 		// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
 | |
| 		var err error
 | |
| 		if originalSource, err = GetSourceByID(source.ID); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	has, err := db.GetEngine(db.DefaultContext).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	} else if has {
 | |
| 		return ErrSourceAlreadyExist{source.Name}
 | |
| 	}
 | |
| 
 | |
| 	_, err = db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if !source.IsActive {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	if settable, ok := source.Cfg.(SourceSettable); ok {
 | |
| 		settable.SetAuthSource(source)
 | |
| 	}
 | |
| 
 | |
| 	registerableSource, ok := source.Cfg.(RegisterableSource)
 | |
| 	if !ok {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	err = registerableSource.RegisterSource()
 | |
| 	if err != nil {
 | |
| 		// restore original values since we cannot update the provider it self
 | |
| 		if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
 | |
| 			log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // CountSources returns number of login sources.
 | |
| func CountSources() int64 {
 | |
| 	count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
 | |
| 	return count
 | |
| }
 | |
| 
 | |
| // ErrSourceNotExist represents a "SourceNotExist" kind of error.
 | |
| type ErrSourceNotExist struct {
 | |
| 	ID int64
 | |
| }
 | |
| 
 | |
| // IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
 | |
| func IsErrSourceNotExist(err error) bool {
 | |
| 	_, ok := err.(ErrSourceNotExist)
 | |
| 	return ok
 | |
| }
 | |
| 
 | |
| func (err ErrSourceNotExist) Error() string {
 | |
| 	return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
 | |
| }
 | |
| 
 | |
| // Unwrap unwraps this as a ErrNotExist err
 | |
| func (err ErrSourceNotExist) Unwrap() error {
 | |
| 	return util.ErrNotExist
 | |
| }
 | |
| 
 | |
| // ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
 | |
| type ErrSourceAlreadyExist struct {
 | |
| 	Name string
 | |
| }
 | |
| 
 | |
| // IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
 | |
| func IsErrSourceAlreadyExist(err error) bool {
 | |
| 	_, ok := err.(ErrSourceAlreadyExist)
 | |
| 	return ok
 | |
| }
 | |
| 
 | |
| func (err ErrSourceAlreadyExist) Error() string {
 | |
| 	return fmt.Sprintf("login source already exists [name: %s]", err.Name)
 | |
| }
 | |
| 
 | |
| // Unwrap unwraps this as a ErrExist err
 | |
| func (err ErrSourceAlreadyExist) Unwrap() error {
 | |
| 	return util.ErrAlreadyExist
 | |
| }
 | |
| 
 | |
// ErrSourceInUse represents a "SourceInUse" kind of error.
type ErrSourceInUse struct {
	ID int64
}

// IsErrSourceInUse reports whether err is, or wraps, an ErrSourceInUse.
func IsErrSourceInUse(err error) bool {
	var target ErrSourceInUse
	return errors.As(err, &target)
}

// Error implements the error interface.
func (err ErrSourceInUse) Error() string {
	return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
}
 |