From 18af835645bed4d8493979a62b8be5a7ab0a826a Mon Sep 17 00:00:00 2001 From: steve-w Date: Mon, 6 Apr 2026 02:09:36 +0000 Subject: [PATCH] Fix validation issues and test setup --- main.py | 23 ++++++++++++++--------- requirements.txt | 3 ++- test_main.py | 21 ++++++++++++++++----- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 5f6ca0d..6cd3637 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,21 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field +import os import sqlite3 from datetime import datetime from typing import Optional, List app = FastAPI(title="Shopping List API", version="0.1.0") -DB_PATH = "shopping.db" + +def get_db_path(): + return os.environ.get("DB_PATH", "shopping.db") + def get_db(): - conn = sqlite3.connect(DB_PATH) + conn = sqlite3.connect(get_db_path()) conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") return conn def init_db(): @@ -98,7 +103,7 @@ class ListWithItems(ListResponse): @app.get("/", tags=["meta"]) def read_root(): - return {"message": "Shopping List API", "version": "0.1.0"} + return {"message": "Shopping List API"} @app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"]) def create_product(product: Product): @@ -188,8 +193,8 @@ def get_list(id: int): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] )) return ListWithItems(**dict(lst_row), items=items) @@ -238,8 +243,8 @@ def add_item(list_id: int, item: ListItemCreate): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] ) @app.patch("/lists/{list_id}/items/{item_id}", tags=["items"]) @@ -293,8 +298,8 @@ def list_items(list_id: int): product_id=row['product_id'], quantity=row['quantity'], added_at=row['added_at'], - product_name=row.get('product_name'), - product_sku=row.get('product_sku') + product_name=row['product_name'], + product_sku=row['product_sku'] ) for row in rows ] diff --git a/requirements.txt b/requirements.txt index 0a4249c..cf4b680 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fastapi uvicorn -sqlalchemy pydantic +pytest +httpx diff --git a/test_main.py b/test_main.py index bb99326..a78ffa7 100644 --- a/test_main.py +++ b/test_main.py @@ -1,16 +1,27 @@ +import os +import tempfile + import pytest from fastapi.testclient import TestClient from main import app, init_db -import sqlite3 -import os -client = TestClient(app) def setup_module(module): - # Use an in-memory database for tests - os.environ["DB_PATH"] = ":memory:" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + module.TEST_DB_PATH = path + os.environ["DB_PATH"] = path init_db() + +def teardown_module(module): + path = getattr(module, "TEST_DB_PATH", None) + if path and os.path.exists(path): + os.remove(path) + + +client = TestClient(app) + def test_root(): r = client.get("/") assert r.status_code == 200