Skip to content

Commit 7860b46

Browse files
committed
fixes #3
1 parent e89bc7e commit 7860b46

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

models.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,22 @@
66
from core import db
77
import bcrypt
88

9+
910
class User(db.Model):
1011
""" User which will be querying resources from the API.
1112
1213
:param db.Model: Base class for database models.
1314
"""
14-
id = db.Column(db.Integer, primary_key=True)
15+
16+
id = db.Column(db.Integer, primary_key=True)
1517
username = db.Column(db.String(40), unique=True)
16-
hashpw = db.Column(db.String(80))
18+
hashpw = db.Column(db.String(80))
1719

1820
@staticmethod
1921
def find_with_password(username, password, *args, **kwargs):
20-
""" Query the User collection for a record with matching username and password hash.
21-
If only a username is supplied, find the first matching document with that username.
22+
""" Query the User collection for a record with matching username and
23+
password hash. If only a username is supplied, find the first matching
24+
document with that username.
2225
2326
:param username: Username of the user.
2427
:param password: Password of the user.
@@ -28,8 +31,11 @@ def find_with_password(username, password, *args, **kwargs):
2831
user = User.query.filter_by(username=username).first()
2932
if user and password:
3033
encodedpw = password.encode('utf-8')
31-
userhash = user.hashpw.encode('utf-8')
32-
return User.query.filter(User.username == username, User.hashpw == bcrypt.hashpw(encodedpw, userhash)).first()
34+
userhash = user.hashpw.encode('utf-8')
35+
return User.query.filter(
36+
User.username == username,
37+
User.hashpw == bcrypt.hashpw(encodedpw, userhash)
38+
).first()
3339
else:
3440
return user
3541

@@ -113,7 +119,7 @@ def find(id):
113119
def delete(self):
114120
""" Delete existing token. """
115121
db.session.delete(self)
116-
db.session(commit)
122+
db.session.commit()
117123
return self
118124

119125
@staticmethod
@@ -144,16 +150,17 @@ class Token(db.Model):
144150
145151
:param db.Model: Base class for database models.
146152
"""
147-
id = db.Column(db.Integer, primary_key=True)
148-
client_id = db.Column(db.String(40), db.ForeignKey('client.client_id'), nullable=False)
149-
client = db.relationship('Client')
150-
user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
151-
user = db.relationship('User')
152-
token_type = db.Column(db.String(40))
153-
access_token = db.Column(db.String(255), unique=True)
153+
id = db.Column(db.Integer, primary_key=True)
154+
client_id = db.Column(db.String(40), db.ForeignKey('client.client_id'),
155+
nullable=False)
156+
client = db.relationship('Client')
157+
user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
158+
user = db.relationship('User')
159+
token_type = db.Column(db.String(40))
160+
access_token = db.Column(db.String(255), unique=True)
154161
refresh_token = db.Column(db.String(255), unique=True)
155-
expires = db.Column(db.DateTime)
156-
scopes = ['']
162+
expires = db.Column(db.DateTime)
163+
scopes = ['']
157164

158165
@staticmethod
159166
def find(access_token=None, refresh_token=None):
@@ -172,8 +179,10 @@ def find(access_token=None, refresh_token=None):
172179
def save(token, request, *args, **kwargs):
173180
""" Save a new token to the database.
174181
175-
:param token: Token dictionary containing access and refresh tokens, plus token type.
176-
:param request: Request dictionary containing information about the client and user.
182+
:param token: Token dictionary containing access and refresh tokens,
183+
plus token type.
184+
:param request: Request dictionary containing information about the
185+
client and user.
177186
:param *args: Variable length argument list.
178187
:param **kwargs: Arbitrary keyword arguments.
179188
"""
@@ -186,7 +195,7 @@ def save(token, request, *args, **kwargs):
186195
[db.session.delete(t) for t in toks]
187196

188197
expires_in = token.pop('expires_in')
189-
expires = datetime.utcnow() + timedelta(seconds=expires_in)
198+
expires = datetime.utcnow() + timedelta(seconds=expires_in)
190199

191200
tok = Token(
192201
access_token=token['access_token'],

validator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
class MyRequestValidator(OAuth2RequestValidator):
99
""" Defines a custom OAuth2 Request Validator based on the Client, User
10-
and Token models.
10+
and Token models.
1111
12-
:param OAuth2RequestValidator: Overrides the OAuth2RequestValidator.
13-
"""
12+
:param OAuth2RequestValidator: Overrides the OAuth2RequestValidator.
13+
"""
1414
def __init__(self):
1515
self._clientgetter = Client.find
1616
self._usergetter = User.find_with_password

views.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,42 @@
77

88
yoloapi = Blueprint('yoloApi', __name__)
99

10+
1011
@yoloapi.route('/oauth/token', methods=['POST'])
1112
@oauth.token_handler
1213
def access_token(*args, **kwargs):
1314
""" This endpoint is for exchanging/refreshing an access token.
1415
15-
Returns a dictionary or None as the extra credentials for creating the token response.
16+
Returns a dictionary or None as the extra credentials for creating the
17+
token response.
1618
1719
:param *args: Variable length argument list.
1820
:param **kwargs: Arbitrary keyword arguments.
1921
"""
2022
return None
2123

24+
2225
@yoloapi.route('/oauth/revoke', methods=['POST'])
2326
@oauth.revoke_handler
2427
def revoke_token():
2528
""" This endpoint allows a user to revoke their access token."""
2629
pass
2730

31+
2832
@yoloapi.route('/', methods=['GET', 'POST'])
2933
def management():
3034
""" This endpoint is for vieweing and adding users and clients. """
3135
if request.method == 'POST' and request.form['submit'] == 'Add User':
3236
User.save(request.form['username'], request.form['password'])
3337
if request.method == 'POST' and request.form['submit'] == 'Add Client':
3438
Client.generate()
35-
return render_template('management.html', users=User.all(), clients=Client.all())
39+
return render_template('management.html', users=User.all(),
40+
clients=Client.all())
41+
3642

3743
@yoloapi.route('/yolo')
3844
@oauth.require_oauth()
3945
def yolo():
4046
""" This is an example endpoint we are trying to protect. """
41-
return "YOLO! Congraulations, you made it through and accessed the protected resource!"
47+
return "YOLO! Congraulations, you made it through and accessed the " \
48+
"protected resource!"

0 commit comments

Comments
 (0)