diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index e8e5d2c54b..1b75004623 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -489,7 +489,7 @@ func AuthorizeOAuth(ctx *context.Context) { }, form.RedirectURI) return } - if err := ctx.Session.Set("CodeChallengeMethod", form.CodeChallenge); err != nil { + if err := ctx.Session.Set("CodeChallenge", form.CodeChallenge); err != nil { handleAuthorizeError(ctx, AuthorizeError{ ErrorCode: ErrorCodeServerError, ErrorDescription: "cannot set code challenge", diff --git a/tests/integration/oauth_test.go b/tests/integration/oauth_test.go index 2b44863ec2..68d168bde5 100644 --- a/tests/integration/oauth_test.go +++ b/tests/integration/oauth_test.go @@ -1535,3 +1535,64 @@ func TestSignUpViaOAuth2FA(t *testing.T) { // Make sure user has to go through 2FA. assert.Equal(t, "/user/webauthn", test.RedirectURL(resp)) } + +func TestAccessTokenWithPKCE(t *testing.T) { + defer tests.PrepareTestEnv(t)() + + var u *url.URL + t.Run("Grant", func(t *testing.T) { + session := loginUser(t, "user4") + req := NewRequestWithValues(t, "POST", "/login/oauth/grant", map[string]string{ + "_csrf": GetCSRF(t, session, "/login/oauth/authorize?client_id=ce5a1322-42a7-11ed-b878-0242ac120002&redirect_uri=b&response_type=code&code_challenge_method=plain&code_challenge=CODE&state=thestate"), + "client_id": "ce5a1322-42a7-11ed-b878-0242ac120002", + "redirect_uri": "b", + "state": "thestate", + "granted": "true", + }) + resp := session.MakeRequest(t, req, http.StatusSeeOther) + + var err error + u, err = url.Parse(test.RedirectURL(resp)) + require.NoError(t, err) + }) + + t.Run("Incorrect code verfifier", func(t *testing.T) { + req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ + "client_id": "ce5a1322-42a7-11ed-b878-0242ac120002", + "code": u.Query().Get("code"), + "code_verifier": "just a guess", + "grant_type": "authorization_code", + "redirect_uri": "b", + }) + resp := MakeRequest(t, req, http.StatusBadRequest) + + var respBody map[string]any + DecodeJSON(t, resp, &respBody) + + if assert.Len(t, respBody, 2) { + assert.Equal(t, "unauthorized_client", respBody["error"]) + assert.Equal(t, "failed PKCE code challenge", respBody["error_description"]) + } + }) + + t.Run("Get access token", func(t *testing.T) { + req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ + "client_id": "ce5a1322-42a7-11ed-b878-0242ac120002", + "code": u.Query().Get("code"), + "code_verifier": "CODE", + "grant_type": "authorization_code", + "redirect_uri": "b", + }) + resp := MakeRequest(t, req, http.StatusOK) + + var respBody map[string]any + DecodeJSON(t, resp, &respBody) + + if assert.Len(t, respBody, 4) { + assert.NotEmpty(t, respBody["access_token"]) + assert.NotEmpty(t, respBody["token_type"]) + assert.NotEmpty(t, respBody["expires_in"]) + assert.NotEmpty(t, respBody["refresh_token"]) + } + }) +}