diff --git a/main.py b/main.py index 5f6ca0d..6cd3637 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,21 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field +import os import sqlite3 from datetime import datetime from typing import Optional, List app = FastAPI(title="Shopping List API", version="0.1.0") -DB_PATH = "shopping.db" + +def get_db_path(): + return os.environ.get("DB_PATH", "shopping.db") + def get_db(): - conn = sqlite3.connect(DB_PATH) + conn = sqlite3.connect(get_db_path()) conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") return conn def init_db(): @@ -98,7 +103,7 @@ class ListWithItems(ListResponse): @app.get("/", tags=["meta"]) def read_root(): - return {"message": "Shopping List API", "version": "0.1.0"} + return {"message": "Shopping List API"} @app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"]) def create_product(product: Product): @@ -188,8 +193,8 @@ def get_list(id: int): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] )) return ListWithItems(**dict(lst_row), items=items) @@ -238,8 +243,8 @@ def add_item(list_id: int, item: ListItemCreate): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] ) @app.patch("/lists/{list_id}/items/{item_id}", tags=["items"]) @@ -293,8 +298,8 @@ def list_items(list_id: int): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] ) for row in rows ] diff --git a/requirements.txt b/requirements.txt index 0a4249c..cf4b680 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fastapi uvicorn -sqlalchemy pydantic +pytest +httpx diff --git a/test_main.py b/test_main.py index bb99326..a78ffa7 100644 --- a/test_main.py +++ b/test_main.py @@ -1,16 +1,27 @@ +import os +import tempfile + import pytest from fastapi.testclient import TestClient from main import app, init_db -import sqlite3 -import os -client = TestClient(app) def setup_module(module): - # Use an in-memory database for tests - os.environ["DB_PATH"] = ":memory:" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + module.TEST_DB_PATH = path + os.environ["DB_PATH"] = path init_db() + +def teardown_module(module): + path = getattr(module, "TEST_DB_PATH", None) + if path and os.path.exists(path): + os.remove(path) + + +client = TestClient(app) + def test_root(): r = client.get("/") assert r.status_code == 200