Move almost all functions' parameter db.Engine to context.Context (#19748)
* Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions
This commit is contained in:
parent
d81e31ad78
commit
fd7d83ace6
232 changed files with 1463 additions and 2108 deletions
|
@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
|
|||
}
|
||||
|
||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
|
||||
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
|
||||
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
|
||||
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
|
@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
|
|||
}
|
||||
|
||||
// CreateGrant generates a grant for an user
|
||||
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
|
||||
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
grant := &OAuth2Grant{
|
||||
ApplicationID: app.ID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
_, err := e.Insert(grant)
|
||||
err := db.Insert(ctx, grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
|
|||
}
|
||||
|
||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.Where("client_id = ?", clientID).Get(app)
|
||||
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
|
||||
if !has {
|
||||
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
|
||||
}
|
||||
|
@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
|
|||
}
|
||||
|
||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.ID(id).Get(app)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(app)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
|
|||
}
|
||||
|
||||
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
|
||||
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
|
||||
apps = make([]*OAuth2Application, 0)
|
||||
err = e.Where("uid = ?", userID).Find(&apps)
|
||||
err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
|
|||
}
|
||||
|
||||
// CreateOAuth2Application inserts a new oauth2 application
|
||||
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
|
||||
}
|
||||
|
||||
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
clientID := uuid.New().String()
|
||||
app := &OAuth2Application{
|
||||
UID: opts.UserID,
|
||||
|
@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
|
|||
ClientID: clientID,
|
||||
RedirectURIs: opts.RedirectURIs,
|
||||
}
|
||||
if _, err := e.Insert(app); err != nil {
|
||||
if err := db.Insert(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return app, nil
|
||||
|
@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
|||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
app, err := getOAuth2ApplicationByID(sess, opts.ID)
|
||||
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
|||
app.Name = opts.Name
|
||||
app.RedirectURIs = opts.RedirectURIs
|
||||
|
||||
if err = updateOAuth2Application(sess, app); err != nil {
|
||||
if err = updateOAuth2Application(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app.ClientSecret = ""
|
||||
|
@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
|||
return app, committer.Commit()
|
||||
}
|
||||
|
||||
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
|
||||
if _, err := e.ID(app.ID).Update(app); err != nil {
|
||||
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
|
||||
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
sess := db.GetEngine(ctx)
|
||||
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
|
||||
return err
|
||||
} else if deleted == 0 {
|
||||
|
@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
|
|||
return err
|
||||
}
|
||||
defer committer.Close()
|
||||
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
|
||||
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
|
||||
return err
|
||||
}
|
||||
return committer.Commit()
|
||||
|
@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
|
|||
}
|
||||
|
||||
// Invalidate deletes the auth code from the database to invalidate this code
|
||||
func (code *OAuth2AuthorizationCode) Invalidate() error {
|
||||
return code.invalidate(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
|
||||
_, err := e.Delete(code)
|
||||
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
|
||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
|
||||
return code.validateCodeChallenge(verifier)
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
|
||||
switch code.CodeChallengeMethod {
|
||||
case "S256":
|
||||
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
|
||||
|
@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
|
|||
}
|
||||
|
||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
|
||||
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
|
||||
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
|
||||
}
|
||||
|
||||
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
auth = new(OAuth2AuthorizationCode)
|
||||
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
|
||||
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
auth.Grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
|
@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
|
|||
}
|
||||
|
||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
|
||||
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
rBytes, err := util.CryptoRandomBytes(32)
|
||||
if err != nil {
|
||||
return &OAuth2AuthorizationCode{}, err
|
||||
|
@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
|
|||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
}
|
||||
if _, err := e.Insert(code); err != nil {
|
||||
if err := db.Insert(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// IncreaseCounter increases the counter and updates the grant
|
||||
func (grant *OAuth2Grant) IncreaseCounter() error {
|
||||
return grant.increaseCount(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
|
||||
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
|
||||
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
|
|||
}
|
||||
|
||||
// SetNonce updates the current nonce value of a grant
|
||||
func (grant *OAuth2Grant) SetNonce(nonce string) error {
|
||||
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
|
||||
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
|
||||
grant.Nonce = nonce
|
||||
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
|
|||
}
|
||||
|
||||
// GetOAuth2GrantByID returns the grant with the given ID
|
||||
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
|
||||
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
|
||||
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(id).Get(grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
|
@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
|
|||
}
|
||||
|
||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
|
||||
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
|
||||
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
|
||||
}
|
||||
|
||||
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
|
||||
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
|
||||
type joinedOAuth2Grant struct {
|
||||
Grant *OAuth2Grant `xorm:"extends"`
|
||||
Application *OAuth2Application `xorm:"extends"`
|
||||
}
|
||||
var results *xorm.Rows
|
||||
var err error
|
||||
if results, err = e.
|
||||
if results, err = db.GetEngine(ctx).
|
||||
Table("oauth2_grant").
|
||||
Where("user_id = ?", uid).
|
||||
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
|
||||
|
@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
|
|||
}
|
||||
|
||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
|
||||
func RevokeOAuth2Grant(grantID, userID int64) error {
|
||||
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
|
||||
}
|
||||
|
||||
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
|
||||
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
|
||||
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
|
||||
_, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue