98 lines
No EOL
3.7 KiB
Python
98 lines
No EOL
3.7 KiB
Python
import unittest
|
|
from src.anonchat import app, db, csrf
|
|
import os
|
|
import tempfile
|
|
import re
|
|
|
|
class CSRFTestCase(unittest.TestCase):
    """Check CSRF handling of the public form ('/') and the admin form ('/admin').

    Every test runs against the imported ``app`` with ``WTF_CSRF_ENABLED``
    turned on and a fresh temporary file registered as
    ``app.config['DATABASE']``.
    """

    def setUp(self):
        """Point the app at a temporary database and create the tables."""
        # Create a temporary database file.
        self.db_fd, db_path = tempfile.mkstemp()

        # Register the cleanup immediately instead of relying on tearDown():
        # unittest does not call tearDown() when setUp() raises (for example
        # if db.create_all() fails), but registered cleanups still run.
        # Cleanups run last-in-first-out, so the descriptor is closed before
        # the file is removed.
        self.addCleanup(os.unlink, db_path)
        self.addCleanup(os.close, self.db_fd)

        app.config['DATABASE'] = db_path
        app.config['TESTING'] = True
        app.config['WTF_CSRF_ENABLED'] = True  # Enable CSRF for our tests
        self.client = app.test_client()

        with app.app_context():
            db.create_all()

    def extract_csrf_token(self, response_data):
        """Extract the CSRF token from the response HTML.

        Args:
            response_data: The decoded HTML text of a response.

        Returns:
            The ``value`` attribute of the first tag carrying
            ``name="csrf_token"``, or ``None`` when no such tag is found.
        """
        # Allow other attributes (id, type, ...) between ``name`` and
        # ``value`` so the lookup does not depend on the exact attribute
        # layout of the rendered input; ``[^>]`` keeps the match inside a
        # single tag.
        match = re.search(
            r'name="csrf_token"[^>]*?\bvalue="(.+?)"', response_data
        )
        return match.group(1) if match else None

    def test_form_with_csrf(self):
        """Test that a form with a valid CSRF token works."""
        # Get the initial page with the form.
        response = self.client.get('/')
        csrf_token = self.extract_csrf_token(response.data.decode('utf-8'))

        # Make sure we found a token.
        self.assertIsNotNone(csrf_token)

        # Submit the form with the token.
        response = self.client.post(
            '/',
            data={
                'csrf_token': csrf_token,
                'message': 'Test message',
            },
            follow_redirects=True,
        )

        # Check that the form was processed successfully.
        self.assertEqual(response.status_code, 200)

    def test_form_without_csrf(self):
        """Test that a form without a CSRF token fails."""
        # Submit a form without a CSRF token.
        response = self.client.post(
            '/',
            data={'message': 'Test message'},
            follow_redirects=True,
        )

        # The request is either rejected outright (400) or redirected to an
        # error page. Because redirects are followed, a 302 is never the
        # final status, so only 400 and 200 are acceptable here.
        self.assertIn(response.status_code, [400, 200])

        # A 200 is only a valid rejection if we ended up on an error page.
        if response.status_code == 200:
            self.assertIn(b'Error', response.data)

    def test_admin_form_with_csrf(self):
        """Test that an admin form with a valid CSRF token works."""
        # Get the initial login page with the form.
        response = self.client.get('/admin')
        csrf_token = self.extract_csrf_token(response.data.decode('utf-8'))

        # Make sure we found a token.
        self.assertIsNotNone(csrf_token)

        # Submit the form with the token (login will fail but that's OK).
        response = self.client.post(
            '/admin',
            data={
                'csrf_token': csrf_token,
                'username': 'test',
                'password': 'test',
            },
            follow_redirects=True,
        )

        # We should get to the login page again (bad credentials).
        self.assertEqual(response.status_code, 200)
        self.assertIn(b'Invalid credentials', response.data)

    def test_admin_form_without_csrf(self):
        """Test that an admin form without a CSRF token fails."""
        # Submit a form without a CSRF token.
        response = self.client.post(
            '/admin',
            data={
                'username': 'test',
                'password': 'test',
            },
            follow_redirects=True,
        )

        # The request should be rejected and redirect to login.
        self.assertEqual(response.status_code, 200)

        # Since we've been redirected, the flash message might not be visible
        # in the HTML, so we only check that we are at the login page with a
        # form that carries a CSRF token field.
        self.assertIn(b'<form method="POST">', response.data)
        self.assertIn(b'<input type="hidden" name="csrf_token"', response.data)
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()