Fix validation issues and test setup
This commit is contained in:
23
main.py
23
main.py
@@ -1,16 +1,21 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
app = FastAPI(title="Shopping List API", version="0.1.0")
|
||||
|
||||
DB_PATH = "shopping.db"
|
||||
|
||||
def get_db_path():
|
||||
return os.environ.get("DB_PATH", "shopping.db")
|
||||
|
||||
|
||||
def get_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = sqlite3.connect(get_db_path())
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
|
||||
def init_db():
|
||||
@@ -98,7 +103,7 @@ class ListWithItems(ListResponse):
|
||||
|
||||
@app.get("/", tags=["meta"])
|
||||
def read_root():
|
||||
return {"message": "Shopping List API", "version": "0.1.0"}
|
||||
return {"message": "Shopping List API"}
|
||||
|
||||
@app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"])
|
||||
def create_product(product: Product):
|
||||
@@ -188,8 +193,8 @@ def get_list(id: int):
|
||||
product_id=row['product_id'],
|
||||
quantity=row['quantity'],
|
||||
added_at=row['added_at'],
|
||||
product_name=row.get('product_name'),
|
||||
product_sku=row.get('product_sku')
|
||||
product_name=row['product_name'],
|
||||
product_sku=row['product_sku']
|
||||
))
|
||||
return ListWithItems(**dict(lst_row), items=items)
|
||||
|
||||
@@ -238,8 +243,8 @@ def add_item(list_id: int, item: ListItemCreate):
|
||||
product_id=row['product_id'],
|
||||
quantity=row['quantity'],
|
||||
added_at=row['added_at'],
|
||||
product_name=row.get('product_name'),
|
||||
product_sku=row.get('product_sku')
|
||||
product_name=row['product_name'],
|
||||
product_sku=row['product_sku']
|
||||
)
|
||||
|
||||
@app.patch("/lists/{list_id}/items/{item_id}", tags=["items"])
|
||||
@@ -293,8 +298,8 @@ def list_items(list_id: int):
|
||||
product_id=row['product_id'],
|
||||
quantity=row['quantity'],
|
||||
added_at=row['added_at'],
|
||||
product_name=row.get('product_name'),
|
||||
product_sku=row.get('product_sku')
|
||||
product_name=row['product_name'],
|
||||
product_sku=row['product_sku']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
sqlalchemy
|
||||
pydantic
|
||||
pytest
|
||||
httpx
|
||||
|
||||
21
test_main.py
21
test_main.py
@@ -1,16 +1,27 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app, init_db
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def setup_module(module):
|
||||
# Use an in-memory database for tests
|
||||
os.environ["DB_PATH"] = ":memory:"
|
||||
fd, path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
module.TEST_DB_PATH = path
|
||||
os.environ["DB_PATH"] = path
|
||||
init_db()
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
path = getattr(module, "TEST_DB_PATH", None)
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_root():
|
||||
r = client.get("/")
|
||||
assert r.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user