Fix validation issues and test setup

This commit is contained in:
steve-w
2026-04-06 02:09:36 +00:00
parent dd6b76e665
commit 18af835645
3 changed files with 32 additions and 15 deletions

23
main.py
View File

@@ -1,16 +1,21 @@
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import os
import sqlite3 import sqlite3
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import Optional, List
app = FastAPI(title="Shopping List API", version="0.1.0") app = FastAPI(title="Shopping List API", version="0.1.0")
DB_PATH = "shopping.db"
def get_db_path():
return os.environ.get("DB_PATH", "shopping.db")
def get_db(): def get_db():
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(get_db_path())
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn return conn
def init_db(): def init_db():
@@ -98,7 +103,7 @@ class ListWithItems(ListResponse):
@app.get("/", tags=["meta"]) @app.get("/", tags=["meta"])
def read_root(): def read_root():
return {"message": "Shopping List API", "version": "0.1.0"} return {"message": "Shopping List API"}
@app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"]) @app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"])
def create_product(product: Product): def create_product(product: Product):
@@ -188,8 +193,8 @@ def get_list(id: int):
product_id=row['product_id'], product_id=row['product_id'],
quantity=row['quantity'], quantity=row['quantity'],
added_at=row['added_at'], added_at=row['added_at'],
product_name=row.get('product_name'), product_name=row['product_name'],
product_sku=row.get('product_sku') product_sku=row['product_sku']
)) ))
return ListWithItems(**dict(lst_row), items=items) return ListWithItems(**dict(lst_row), items=items)
@@ -238,8 +243,8 @@ def add_item(list_id: int, item: ListItemCreate):
product_id=row['product_id'], product_id=row['product_id'],
quantity=row['quantity'], quantity=row['quantity'],
added_at=row['added_at'], added_at=row['added_at'],
product_name=row.get('product_name'), product_name=row['product_name'],
product_sku=row.get('product_sku') product_sku=row['product_sku']
) )
@app.patch("/lists/{list_id}/items/{item_id}", tags=["items"]) @app.patch("/lists/{list_id}/items/{item_id}", tags=["items"])
@@ -293,8 +298,8 @@ def list_items(list_id: int):
product_id=row['product_id'], product_id=row['product_id'],
quantity=row['quantity'], quantity=row['quantity'],
added_at=row['added_at'], added_at=row['added_at'],
product_name=row.get('product_name'), product_name=row['product_name'],
product_sku=row.get('product_sku') product_sku=row['product_sku']
) )
for row in rows for row in rows
] ]

View File

@@ -1,4 +1,5 @@
fastapi fastapi
uvicorn uvicorn
sqlalchemy
pydantic pydantic
pytest
httpx

View File

@@ -1,16 +1,27 @@
import os
import tempfile
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app, init_db from main import app, init_db
import sqlite3
import os
client = TestClient(app)
def setup_module(module): def setup_module(module):
# Use an in-memory database for tests fd, path = tempfile.mkstemp(suffix=".db")
os.environ["DB_PATH"] = ":memory:" os.close(fd)
module.TEST_DB_PATH = path
os.environ["DB_PATH"] = path
init_db() init_db()
def teardown_module(module):
path = getattr(module, "TEST_DB_PATH", None)
if path and os.path.exists(path):
os.remove(path)
client = TestClient(app)
def test_root(): def test_root():
r = client.get("/") r = client.get("/")
assert r.status_code == 200 assert r.status_code == 200