Fix validation issues and test setup
This commit is contained in:
23
main.py
23
main.py
@@ -1,16 +1,21 @@
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
app = FastAPI(title="Shopping List API", version="0.1.0")
|
app = FastAPI(title="Shopping List API", version="0.1.0")
|
||||||
|
|
||||||
DB_PATH = "shopping.db"
|
|
||||||
|
def get_db_path():
|
||||||
|
return os.environ.get("DB_PATH", "shopping.db")
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(get_db_path())
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
@@ -98,7 +103,7 @@ class ListWithItems(ListResponse):
|
|||||||
|
|
||||||
@app.get("/", tags=["meta"])
|
@app.get("/", tags=["meta"])
|
||||||
def read_root():
|
def read_root():
|
||||||
return {"message": "Shopping List API", "version": "0.1.0"}
|
return {"message": "Shopping List API"}
|
||||||
|
|
||||||
@app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"])
|
@app.post("/products", response_model=ProductResponse, status_code=201, tags=["products"])
|
||||||
def create_product(product: Product):
|
def create_product(product: Product):
|
||||||
@@ -188,8 +193,8 @@ def get_list(id: int):
|
|||||||
product_id=row['product_id'],
|
product_id=row['product_id'],
|
||||||
quantity=row['quantity'],
|
quantity=row['quantity'],
|
||||||
added_at=row['added_at'],
|
added_at=row['added_at'],
|
||||||
product_name=row.get('product_name'),
|
product_name=row['product_name'],
|
||||||
product_sku=row.get('product_sku')
|
product_sku=row['product_sku']
|
||||||
))
|
))
|
||||||
return ListWithItems(**dict(lst_row), items=items)
|
return ListWithItems(**dict(lst_row), items=items)
|
||||||
|
|
||||||
@@ -238,8 +243,8 @@ def add_item(list_id: int, item: ListItemCreate):
|
|||||||
product_id=row['product_id'],
|
product_id=row['product_id'],
|
||||||
quantity=row['quantity'],
|
quantity=row['quantity'],
|
||||||
added_at=row['added_at'],
|
added_at=row['added_at'],
|
||||||
product_name=row.get('product_name'),
|
product_name=row['product_name'],
|
||||||
product_sku=row.get('product_sku')
|
product_sku=row['product_sku']
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.patch("/lists/{list_id}/items/{item_id}", tags=["items"])
|
@app.patch("/lists/{list_id}/items/{item_id}", tags=["items"])
|
||||||
@@ -293,8 +298,8 @@ def list_items(list_id: int):
|
|||||||
product_id=row['product_id'],
|
product_id=row['product_id'],
|
||||||
quantity=row['quantity'],
|
quantity=row['quantity'],
|
||||||
added_at=row['added_at'],
|
added_at=row['added_at'],
|
||||||
product_name=row.get('product_name'),
|
product_name=row['product_name'],
|
||||||
product_sku=row.get('product_sku')
|
product_sku=row['product_sku']
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
sqlalchemy
|
|
||||||
pydantic
|
pydantic
|
||||||
|
pytest
|
||||||
|
httpx
|
||||||
|
|||||||
21
test_main.py
21
test_main.py
@@ -1,16 +1,27 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from main import app, init_db
|
from main import app, init_db
|
||||||
import sqlite3
|
|
||||||
import os
|
|
||||||
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
def setup_module(module):
|
def setup_module(module):
|
||||||
# Use an in-memory database for tests
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
os.environ["DB_PATH"] = ":memory:"
|
os.close(fd)
|
||||||
|
module.TEST_DB_PATH = path
|
||||||
|
os.environ["DB_PATH"] = path
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
|
|
||||||
|
def teardown_module(module):
|
||||||
|
path = getattr(module, "TEST_DB_PATH", None)
|
||||||
|
if path and os.path.exists(path):
|
||||||
|
os.remove(path)
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
def test_root():
|
def test_root():
|
||||||
r = client.get("/")
|
r = client.get("/")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user