Move almost all functions' parameter db.Engine to context.Context (#19748)

* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
This commit is contained in:
Lunny Xiao 2022-05-20 22:08:52 +08:00 committed by GitHub
parent d81e31ad78
commit fd7d83ace6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
232 changed files with 1463 additions and 2108 deletions

View file

@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
}
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
}
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
}
// CreateGrant generates a grant for an user
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
}
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
_, err := e.Insert(grant)
err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
}
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.Where("client_id = ?", clientID).Get(app)
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.ID(id).Get(app)
has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
}
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
}
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
apps = make([]*OAuth2Application, 0)
err = e.Where("uid = ?", userID).Find(&apps)
err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
return
}
@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
}
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
}
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
}
if _, err := e.Insert(app); err != nil {
if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return nil, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
app, err := getOAuth2ApplicationByID(sess, opts.ID)
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
if err = updateOAuth2Application(sess, app); err != nil {
if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return app, committer.Commit()
}
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
if _, err := e.ID(app.ID).Update(app); err != nil {
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
return err
}
return nil
}
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
sess := db.GetEngine(ctx)
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
return err
} else if deleted == 0 {
@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
return err
}
defer committer.Close()
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
return err
}
return committer.Commit()
@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
}
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate() error {
return code.invalidate(db.GetEngine(db.DefaultContext))
}
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
_, err := e.Delete(code)
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
return err
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
return code.validateCodeChallenge(verifier)
}
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
switch code.CodeChallengeMethod {
case "S256":
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
}
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
}
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil
}
auth.Grant = new(OAuth2Grant)
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
}
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return &OAuth2AuthorizationCode{}, err
@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
if _, err := e.Insert(code); err != nil {
if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter() error {
return grant.increaseCount(db.GetEngine(db.DefaultContext))
}
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
}
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
if err != nil {
return err
}
@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
}
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(nonce string) error {
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
}
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
}
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.ID(id).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
}
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
if results, err = e.
if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(grantID, userID int64) error {
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
}
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
_, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
return err
}