diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..b7c70d1 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,3 @@ +[*.py] +indent_size = 4 +indent_style = space diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..1a9564c --- /dev/null +++ b/api/__init__.py @@ -0,0 +1,60 @@ +from flask import Flask, Blueprint +from . import commands, login +from .settings import ProdConfig, Config +from .extensions import db, migrate, jwt +from .exceptions import ApiException + + +def create_app(config: Config = ProdConfig) -> Flask: + """An application factory, as explained here: + http://flask.pocoo.org/docs/patterns/appfactories/. + + :param config_object: The configuration object to use. + """ + app = Flask(__name__.split('.')[0]) + app.url_map.strict_slashes = False + app.config.from_object(config) + register_extensions(app) + register_blueprints(app) + register_errorhandlers(app) + register_shellcontext(app) + register_commands(app) + + return app + + +def register_extensions(app: Flask): + """Register Flask extensions.""" + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + + +def register_blueprints(app: Flask): + """Register Flask blueprints.""" + pass + + +def register_errorhandlers(app: Flask): + def errorHandler(error: ApiException): + return error.to_response() + + app.errorhandler(ApiException)(errorHandler) + pass + + +def register_shellcontext(app: Flask): + """Register shell context objects.""" + def shell_context(): + """Shell context objects.""" + return { + 'db': db, + } + + app.shell_context_processor(shell_context) + + +def register_commands(app: Flask): + """Register Click commands.""" + app.cli.add_command(commands.clean) + app.cli.add_command(commands.urls) diff --git a/api/commands.py b/api/commands.py new file mode 100644 index 0000000..4c670b5 --- /dev/null +++ b/api/commands.py @@ -0,0 +1,81 @@ +import os +import click + +from flask import current_app +from flask.cli import with_appcontext +from werkzeug.exceptions import MethodNotAllowed, NotFound + +@click.command() +def clean(): + """Remove *.pyc and *.pyo files recursively starting at current directory. + + Borrowed from Flask-Script, converted to use Click. + """ + for dirpath, _, filenames in os.walk('.'): + for filename in filenames: + if filename.endswith('.pyc') or filename.endswith('.pyo'): + full_pathname = os.path.join(dirpath, filename) + click.echo('Removing {}'.format(full_pathname)) + os.remove(full_pathname) + + +@click.command() +@click.option('--url', default=None, + help='Url to test (ex. /static/image.png)') +@click.option('--order', default='rule', + help='Property on Rule to order by (default: rule)') +@with_appcontext +def urls(url, order): + """Display all of the url matching routes for the project. + + Borrowed from Flask-Script, converted to use Click. + """ + rows = [] + column_headers = ('Rule', 'Endpoint', 'Arguments') + + if url: + try: + rule, arguments = ( + current_app.url_map.bind('localhost') + .match(url, return_rule=True)) + rows.append((rule.rule, rule.endpoint, arguments)) + column_length = 3 + except (NotFound, MethodNotAllowed) as e: + rows.append(('<{}>'.format(e), None, None)) + column_length = 1 + else: + rules = sorted( + current_app.url_map.iter_rules(), + key=lambda rule: getattr(rule, order)) + for rule in rules: + rows.append((rule.rule, rule.endpoint, None)) + column_length = 2 + + str_template = '' + table_width = 0 + + if column_length >= 1: + max_rule_length = max(len(r[0]) for r in rows) + max_rule_length = max_rule_length if max_rule_length > 4 else 4 + str_template += '{:' + str(max_rule_length) + '}' + table_width += max_rule_length + + if column_length >= 2: + max_endpoint_length = max(len(str(r[1])) for r in rows) + max_endpoint_length = ( + max_endpoint_length if max_endpoint_length > 8 else 8) + str_template += ' {:' + str(max_endpoint_length) + '}' + table_width += 2 + max_endpoint_length + + if column_length >= 3: + max_arguments_length = max(len(str(r[2])) for r in rows) + max_arguments_length = ( + max_arguments_length if max_arguments_length > 9 else 9) + str_template += ' {:' + str(max_arguments_length) + '}' + table_width += 2 + max_arguments_length + + click.echo(str_template.format(*column_headers[:column_length])) + click.echo('-' * table_width) + + for row in rows: + click.echo(str_template.format(*row[:column_length])) diff --git a/api/constants.py b/api/constants.py new file mode 100644 index 0000000..30b4582 --- /dev/null +++ b/api/constants.py @@ -0,0 +1,2 @@ +API_PASS='API_PASS' +API_USER='super' diff --git a/api/database.py b/api/database.py new file mode 100644 index 0000000..49dbb31 --- /dev/null +++ b/api/database.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +"""Database module, including the SQLAlchemy database object and DB-related utilities.""" +from sqlalchemy.orm import relationship + +from .extensions import db + +# Alias common SQLAlchemy names +Column = db.Column +relationship = relationship +Model = db.Model + +# From Mike Bayer's "Building the app" talk +# https://speakerdeck.com/zzzeek/building-the-app +class SurrogatePK(object): + """A mixin that adds a surrogate integer 'primary key' column named ``id`` \ + to any declarative-mapped class. + """ + + __table_args__ = {'extend_existing': True} + + id = db.Column(db.Integer, primary_key=True) + + @classmethod + def get_by_id(cls, record_id): + """Get record by ID.""" + if any( + (isinstance(record_id, (str, bytes)) and record_id.isdigit(), + isinstance(record_id, (int, float))), + ): + return cls.query.get(int(record_id)) + + +def reference_col(tablename, nullable=False, pk_name='id', **kwargs): + """Column that adds primary key foreign key reference. + + Usage: :: + + category_id = reference_col('category') + category = relationship('Category', backref='categories') + """ + return db.Column( + db.ForeignKey('{0}.{1}'.format(tablename, pk_name)), + nullable=nullable, **kwargs) diff --git a/api/exceptions.py b/api/exceptions.py new file mode 100644 index 0000000..5511173 --- /dev/null +++ b/api/exceptions.py @@ -0,0 +1,22 @@ +from flask import jsonify + +class ApiException(Exception): + status_code = 500 + + def __init__(self, status_code: int, message) -> None: + super().__init__() + self.status_code = status_code + self.message = message + + def to_response(self): + rv = jsonify(self.message) + rv.status_code = self.status_code + return rv + +class NotFoundException(ApiException): + def __init__(self, message) -> None: + super().__init__(404, message) + +class BadRequestException(ApiException): + def __init__(self, message) -> None: + super().__init__(400, message) diff --git a/api/extensions.py b/api/extensions.py new file mode 100644 index 0000000..b858d81 --- /dev/null +++ b/api/extensions.py @@ -0,0 +1,34 @@ +from flask_jwt_extended import JWTManager +from flask_migrate import Migrate +from flask_sqlalchemy import SQLAlchemy, Model + +class CRUDMixin(Model): + """Mixin that adds convenience methods for CRUD (create, read, update, delete) operations.""" + + @classmethod + def create(cls, **kwargs): + """Create a new record and save it the database.""" + instance = cls(**kwargs) + return instance.save() + + def update(self, commit=True, **kwargs): + """Update specific fields of a record.""" + for attr, value in kwargs.items(): + setattr(self, attr, value) + return commit and self.save() or self + + def save(self, commit=True): + """Save the record.""" + db.session.add(self) + if commit: + db.session.commit() + return self + + def delete(self, commit=True): + """Remove the record from the database.""" + db.session.delete(self) + return commit and db.session.commit() + +db = SQLAlchemy(model_class=CRUDMixin) +migrate = Migrate() +jwt = JWTManager() diff --git a/api/settings.py b/api/settings.py new file mode 100644 index 0000000..44b0fd7 --- /dev/null +++ b/api/settings.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +"""Application configuration.""" +import os +from datetime import timedelta + + +class Config(object): + """Base configuration.""" + + SECRET_KEY = os.environ['SECRET_KEY'] + API_PASS = os.environ['API_PASS'] + APP_DIR = os.path.abspath(os.path.dirname(__file__)) # This directory + PROJECT_ROOT = os.path.abspath(os.path.join(APP_DIR, os.pardir)) + DEBUG_TB_INTERCEPT_REDIRECTS = False + CACHE_TYPE = 'simple' # Can be "memcached", "redis", etc. + SQLALCHEMY_TRACK_MODIFICATIONS = False + JWT_AUTH_USERNAME_KEY = 'email' + JWT_AUTH_HEADER_PREFIX = 'Token' + + +class ProdConfig(Config): + """Production configuration.""" + + ENV = 'prod' + DEBUG = False + DB_NAME = os.environ.get('DB_NAME', 'prod.db') + DB_PATH = os.environ.get('DB_PATH', os.path.join(Config.PROJECT_ROOT, DB_NAME)) + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', + f'sqlite:///{DB_PATH}') + + +class DevConfig(Config): + """Development configuration.""" + + ENV = 'dev' + DEBUG = True + DB_NAME = 'dev.db' + # Put the db file in project root + DB_PATH = os.path.join(Config.PROJECT_ROOT, DB_NAME) + SQLALCHEMY_DATABASE_URI = f'sqlite:///{DB_PATH}' + CACHE_TYPE = 'simple' # Can be "memcached", "redis", etc. + JWT_ACCESS_TOKEN_EXPIRES = timedelta(60 * 60) + + +class TestConfig(Config): + """Test configuration.""" + + TESTING = True + DEBUG = True + SQLALCHEMY_DATABASE_URI = 'sqlite://' + # For faster tests; needs at least 4 to avoid "ValueError: Invalid rounds" + BCRYPT_LOG_ROUNDS = 4 diff --git a/requirements.txt b/requirements.txt index dd0d172..ec358b3 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run.py b/run.py new file mode 100644 index 0000000..ce15b40 --- /dev/null +++ b/run.py @@ -0,0 +1,10 @@ +from flask.helpers import get_debug_flag + +from api import create_app +from api.settings import DevConfig, ProdConfig + +CONFIG = DevConfig if get_debug_flag() else ProdConfig +app = create_app(CONFIG) + +if __name__ == "__main__": + app.run()