diff --git a/backend/app/models.py b/backend/app/models.py index 9cace57..142d734 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -89,7 +89,7 @@ class User(UserMixin, db.Model): .join(Post.author.of_type(Author)) .join(Author.followers.of_type(Follower), isouter=True) .where(sa.or_( - Follower.id == self.id + Follower.id == self.id, Author.id == self.id )) .group_by(Post) diff --git a/backend/config.py b/backend/config.py index d05ad60..e1efeb1 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,8 +5,8 @@ basedir = os.path.abspath(os.path.dirname(__file__)) class Config: SECRET_KEY = os.environ.get('FLASK_SECRET_KEY') or 'flasksk' - #SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'zapp.db') - SQLALCHEMY_DATABASE_URI = 'mariadb+mariadbconnector://flasku:' + os.environ.get('MYSQL_PASSWORD') + '@db:3306/flask' + SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'zapp.db') + #SQLALCHEMY_DATABASE_URI = 'mariadb+mariadbconnector://flasku:' + os.environ.get('MYSQL_PASSWORD') + '@db:3306/flask' #MAIL_SERVER = 'pmb' MAIL_SERVER = '' diff --git a/backend/tests.py b/backend/tests.py index 1b31bb1..305ed03 100644 --- a/backend/tests.py +++ b/backend/tests.py @@ -2,4 +2,94 @@ import os os.environ['DATABASE_URL'] = 'sqlite://' from datetime import datetime, timezone, timedelta +import unittest +from app import app, db +from app.models import User, Post + + +class UserModelCase(unittest.TestCase): + def setUp(self): + self.app_context = app.app_context() + self.app_context.push() + db.create_all() + + def tearDown(self): + db.session.remove() + db.drop_all() + self.app_context.pop() + + def test_password_hashing(self): + u = User(username='susan', email='susan@example.com') + u.set_password('cat') + self.assertFalse(u.check_password('dog')) + self.assertTrue(u.check_password('cat')) + + def test_follow(self): + u1 = User(username='john', email='john@example.com') + u2 = User(username='susan', email='susan@example.com') + db.session.add(u1) + db.session.add(u2) + db.session.commit() + following = db.session.scalars(u1.following.select()).all() + followers = db.session.scalars(u2.followers.select()).all() + self.assertEqual(following, []) + self.assertEqual(followers, []) + + u1.follow(u2) + db.session.commit() + self.assertTrue(u1.is_following(u2)) + self.assertEqual(u1.following_count(), 1) + self.assertEqual(u2.followers_count(), 1) + u1_following = db.session.scalars(u1.following.select()).all() + u2_followers = db.session.scalars(u2.followers.select()).all() + self.assertEqual(u1_following[0].username, 'susan') + self.assertEqual(u2_followers[0].username, 'john') + + u1.unfollow(u2) + db.session.commit() + self.assertFalse(u1.is_following(u2)) + self.assertEqual(u1.following_count(), 0) + self.assertEqual(u2.followers_count(), 0) + + def test_follow_posts(self): + # create four users + u1 = User(username='john', email='john@example.com') + u2 = User(username='susan', email='susan@example.com') + u3 = User(username='mary', email='mary@example.com') + u4 = User(username='david', email='david@example.com') + db.session.add_all([u1, u2, u3, u4]) + + # create four posts + now = datetime.now(timezone.utc) + p1 = Post(body="post from john", author=u1, + timestamp=now + timedelta(seconds=1)) + p2 = Post(body="post from susan", author=u2, + timestamp=now + timedelta(seconds=4)) + p3 = Post(body="post from mary", author=u3, + timestamp=now + timedelta(seconds=3)) + p4 = Post(body="post from david", author=u4, + timestamp=now + timedelta(seconds=2)) + db.session.add_all([p1, p2, p3, p4]) + db.session.commit() + + # setup the followers + u1.follow(u2) # john follows susan + u1.follow(u4) # john follows david + u2.follow(u3) # susan follows mary + u3.follow(u4) # mary follows david + db.session.commit() + + # check the following posts of each user + f1 = db.session.scalars(u1.following_posts()).all() + f2 = db.session.scalars(u2.following_posts()).all() + f3 = db.session.scalars(u3.following_posts()).all() + f4 = db.session.scalars(u4.following_posts()).all() + self.assertEqual(f1, [p2, p4, p1]) + self.assertEqual(f2, [p2, p3]) + self.assertEqual(f3, [p3, p4]) + self.assertEqual(f4, [p4]) + + +if __name__ == '__main__': + unittest.main(verbosity=2)