# cs50/app.py (Python, 263 lines, 7.5 KiB)

import os
import datetime
from cs50 import SQL
from flask import Flask, flash, redirect, render_template, request, session
from flask_session import Session
from werkzeug.security import check_password_hash, generate_password_hash
from helpers import apology, login_required, lookup, usd
# Create the Flask application object
app = Flask(__name__)

# Expose the usd() helper to templates as the "usd" Jinja filter
app.jinja_env.filters["usd"] = usd

# Keep session data on the server's filesystem (via Flask-Session) rather
# than in signed cookies, and use non-permanent sessions
app.config.update(
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
)
Session(app)

# Database handle from the CS50 library, backed by the local SQLite file
db = SQL("sqlite:///finance.db")
@app.after_request
def after_request(response):
    """Attach headers telling clients and proxies not to cache the response."""
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Expires": 0,
        "Pragma": "no-cache",
    }
    for header_name, header_value in no_cache_headers.items():
        response.headers[header_name] = header_value
    return response
@app.route("/")
@login_required
def index():
    """Show the logged-in user's portfolio.

    For every symbol whose net share count is positive, fetch the current
    price and compute the position's value. Render ``home.html`` with:

    * ``stocks`` - the aggregated rows, each extended with ``shares``,
      ``price`` (raw number) and ``total`` (USD-formatted string);
    * ``cash`` - the user's cash as a USD-formatted string;
    * ``total`` - cash plus the value of all positions, USD-formatted.

    If a current price cannot be fetched for a held symbol, an apology page
    is returned instead of crashing.
    """
    user_id = session["user_id"]

    cash = db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"]

    stocks = db.execute(
        "SELECT symbol, sum(shares) FROM transactions WHERE user_id = ? "
        "GROUP BY symbol HAVING sum(shares) > 0",
        user_id,
    )

    stocks_total = 0
    for stock in stocks:
        current = lookup(stock["symbol"].upper())
        # lookup() may return None; report that clearly rather than failing
        # with a TypeError on None["price"].
        if current is None:
            return apology(f"Could not fetch price for {stock['symbol']}")

        shares = stock["sum(shares)"]
        position_value = current["price"] * shares

        stock["shares"] = shares
        stock["price"] = current["price"]
        stock["total"] = usd(position_value)
        stocks_total += position_value

    return render_template(
        "home.html",
        cash=usd(cash),
        stocks=stocks,
        total=usd(cash + stocks_total),
    )
@app.route("/buy", methods=["GET", "POST"])
@login_required
def buy():
    """Buy shares of a stock.

    GET renders the purchase form. POST validates the submitted symbol and
    share count, checks that the user's cash covers the purchase at the
    current price, deducts the cost, records the purchase in
    ``transactions`` and redirects to the portfolio page.
    """
    if request.method == "GET":
        return render_template("buy.html")

    symbol = request.form.get("symbol")
    if not symbol:
        return apology("Not Symbol")

    stock = lookup(symbol.upper())
    if stock is None:
        return apology("Symbol not found")

    # Shares must be a positive whole number. Reject missing, empty,
    # non-numeric, fractional, zero and negative values; a negative count
    # would otherwise *credit* cash to the user.
    try:
        shares = int(request.form.get("shares"))
    except (TypeError, ValueError):
        return apology("Shares must be a positive integer")
    if shares <= 0:
        return apology("Shares must be a positive integer")

    user_id = session["user_id"]
    cost = shares * stock["price"]

    cash = db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"]
    if cash < cost:
        return apology("U broke, m8!")

    db.execute("UPDATE users SET cash = ? WHERE id = ?", cash - cost, user_id)
    db.execute(
        "INSERT INTO transactions (user_id, symbol, shares, price, date) VALUES (?, ?, ?, ?, ?)",
        user_id, stock["symbol"], shares, stock["price"], datetime.datetime.now(),
    )

    flash("Bought!")
    return redirect("/")
@app.route("/history")
@login_required
def history():
    """Show the logged-in user's transaction history.

    Only rows whose ``user_id`` matches the session's user are returned, so
    one user can never see another user's trades.
    """
    user_id = session["user_id"]
    stocks = db.execute("SELECT * FROM transactions WHERE user_id = ?", user_id)
    return render_template("history.html", stocks=stocks)
@app.route("/login", methods=["GET", "POST"])
def login():
    """Log a user in.

    Any existing session is cleared first. GET renders the login form.
    POST requires a username and password, checks them against the
    ``users`` table, stores the user's id in the session on success and
    redirects to "/"; failures return a 403 apology.
    """
    # Start from a clean session so no previous user_id lingers
    session.clear()

    if request.method != "POST":
        # Reached via GET (a link or a redirect): just show the form
        return render_template("login.html")

    username = request.form.get("username")
    password = request.form.get("password")

    if not username:
        return apology("must provide username", 403)
    if not password:
        return apology("must provide password", 403)

    matches = db.execute("SELECT * FROM users WHERE username = ?", username)

    # Exactly one account must match, and its stored hash must verify
    credentials_ok = len(matches) == 1 and check_password_hash(matches[0]["hash"], password)
    if not credentials_ok:
        return apology("invalid username and/or password", 403)

    session["user_id"] = matches[0]["id"]
    return redirect("/")
@app.route("/logout")
def logout():
    """Log the current user out by clearing the session, then redirect to "/"."""
    session.clear()  # drops user_id along with anything else stored in the session
    return redirect("/")
@app.route("/quote", methods=["GET", "POST"])
@login_required
def quote():
    """Look up and display the current price of a stock.

    GET renders an empty quote form. POST looks up the submitted symbol
    (upper-cased) and renders the same template with the lookup result,
    whose price has been replaced by its USD-formatted string.
    """
    if request.method != "POST":
        return render_template("quote.html")

    symbol = request.form.get("symbol")
    if not symbol:
        return apology("Not Symbol")

    result = lookup(symbol.upper())
    if result is None:
        return apology("Symbol not found")

    # The template receives the price already formatted for display
    result["price"] = usd(result["price"])
    return render_template("quote.html", stock=result)
@app.route("/register", methods=["GET", "POST"])
def register():
    """Register a new user.

    GET renders the registration form. POST requires a username, a password
    and a matching confirmation, rejects usernames that already exist,
    stores the username with a hash of the password, and redirects to "/".
    Validation failures return a 400 apology.
    """
    if request.method != "POST":
        return render_template("register.html")

    username = request.form.get("username")
    password = request.form.get("password")

    if not username:
        return apology("must provide username", 400)
    if not password:
        return apology("must provide password", 400)
    if password != request.form.get("confirmation"):
        return apology("passwords don't match", 400)

    # Reject duplicates explicitly. Without this, inserting an existing name
    # either fails with an unhandled constraint error or creates a second
    # row that login() can never authenticate (it requires exactly one match).
    if db.execute("SELECT id FROM users WHERE username = ?", username):
        return apology("username already taken", 400)

    password_hash = generate_password_hash(password)
    try:
        db.execute("INSERT INTO users(username, hash) VALUES(?, ?)", username, password_hash)
    except ValueError:
        # The CS50 SQL wrapper raises ValueError on constraint violations,
        # e.g. a concurrent registration of the same username.
        return apology("username already taken", 400)

    return redirect("/")
@app.route("/sell", methods=["GET", "POST"])
@login_required
def sell():
    """Sell shares of a stock the user holds.

    GET renders the sell form listing the symbols with a positive net share
    count. POST validates the symbol and share count, checks the user holds
    at least that many shares, credits the sale value at the current price
    to the user's cash, records the sale as a negative-share transaction and
    redirects to the portfolio page.
    """
    user_id = session["user_id"]

    if request.method != "POST":
        holdings = db.execute(
            "SELECT symbol FROM transactions WHERE user_id = ? GROUP BY symbol HAVING sum(shares) > 0",
            user_id,
        )
        return render_template("sell.html", stocks=holdings)

    symbol = request.form.get("symbol")
    if not symbol:
        return apology("Not Symbol")

    # Validate before converting: a missing or non-numeric value used to
    # crash, and zero/negative counts must not be accepted as a sale.
    try:
        shares = int(request.form.get("shares"))
    except (TypeError, ValueError):
        return apology("Shares must be a positive integer")
    if shares <= 0:
        return apology("Shares must be a positive integer")

    stock = lookup(symbol.upper())
    if stock is None:
        return apology("Symbol not found")

    # Query by the canonical symbol from lookup(), which is what buy()
    # stores. SUM() over no rows is NULL (None), meaning zero shares owned.
    owned = db.execute(
        "SELECT sum(shares) AS owned FROM transactions WHERE user_id = ? AND symbol = ?",
        user_id, stock["symbol"],
    )[0]["owned"] or 0
    if owned < shares:
        return apology("Not enough shares")

    proceeds = shares * stock["price"]
    cash = db.execute("SELECT cash FROM users WHERE id = ?", user_id)[0]["cash"]
    db.execute("UPDATE users SET cash = ? WHERE id = ?", cash + proceeds, user_id)

    # Sales are stored as negative share counts so sum(shares) gives holdings
    db.execute(
        "INSERT INTO transactions (user_id, symbol, shares, price, date) VALUES (?, ?, ?, ?, ?)",
        user_id, stock["symbol"], -shares, stock["price"], datetime.datetime.now(),
    )

    flash("Sold!")
    return redirect("/")