Add exempt

This commit is contained in:
wcorrales 2021-03-07 12:17:58 -06:00
parent 9e9894dd8d
commit 17b3d3a2ce
3 changed files with 45 additions and 3 deletions

View File

@ -2,3 +2,9 @@
----------------
* Released initial version.
0.2.0 2021-03-05
----------------
* Add exempt.

View File

@ -4,7 +4,7 @@ import os
from urllib.parse import urlparse
from itsdangerous import BadData, SignatureExpired, URLSafeTimedSerializer
from quart import current_app, g, request, session
from quart import Blueprint, current_app, g, request, session
from werkzeug.exceptions import BadRequest
from werkzeug.security import safe_str_cmp
from wtforms import ValidationError
@ -143,6 +143,9 @@ class CSRFProtect:
"""
def __init__(self, app=None):
self._exempt_views = set()
self._exempt_blueprints = set()
if app:
self.init_app(app)
@ -178,6 +181,15 @@ class CSRFProtect:
if not request.endpoint:
return
if request.blueprint in self._exempt_blueprints:
return
view = app.view_functions.get(request.endpoint)
dest = f'{view.__module__}.{view.__name__}'
if dest in self._exempt_views:
return
await self.protect()
async def _get_csrf_token(self):
@ -227,6 +239,30 @@ class CSRFProtect:
g.csrf_valid = True # mark this request as CSRF valid
def exempt(self, view):
"""Mark a view or blueprint to be excluded from CSRF protection.
::
@app.route('/some-view', methods=['POST'])
@csrf.exempt
def some_view():
...
::
bp = Blueprint(...)
csrf.exempt(bp)
"""
if isinstance(view, Blueprint):
self._exempt_blueprints.add(view.name)
return view
if isinstance(view, str):
view_location = view
else:
view_location = '.'.join((view.__module__, view.__name__))
self._exempt_views.add(view_location)
return view
def _error_response(self, reason):
raise CSRFError(reason)

View File

@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setup(
name='quart-csrf',
version='0.1',
version='0.2',
author='Wagner Corrales',
author_email='wagnerc4@gmail.com',
description='Quart CSRF Protection',
@ -16,7 +16,7 @@ setup(
install_requires=['itsdangerous', 'quart', 'wtforms'],
license='MIT',
classifiers=[
'Development Status :: 1 - Alpha',
'Development Status :: 3 - Alpha',
'Environment :: Web Environment',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',