Skip to content

Commit 8bd5e2c

Browse files
Enhance product and cart item management
Added checks for seller_id and product stock in create_product and add_cart_item functions. Updated remove_cart_item to return a boolean and modified get_address function to accept address_id instead of user_id.
1 parent 7cd7679 commit 8bd5e2c

1 file changed

Lines changed: 38 additions & 8 deletions

File tree

app/crud.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ def get_product(db: Session, product_id: int) -> Optional[models.Product]:
6565
return db.query(models.Product).filter(models.Product.id == product_id).first()
6666

6767
def create_product(db: Session, product: schemas.ProductBase, seller_id: int = None) -> models.Product:
68+
# Prevent client-provided seller_id from being trusted. Start from product dict
6869
product_data = product.dict()
69-
if seller_id:
70+
# Remove any seller_id supplied by client to avoid privilege escalation
71+
product_data.pop("seller_id", None)
72+
# If an explicit seller_id is provided by the server-side caller, use it
73+
if seller_id is not None:
7074
product_data["seller_id"] = seller_id
7175
db_product = models.Product(**product_data)
7276
db.add(db_product)
@@ -118,6 +122,13 @@ def create_cart(db: Session, user_id: int) -> models.Cart:
118122
return cart
119123

120124
def add_cart_item(db: Session, cart_id: int, item: schemas.CartItemBase) -> models.CartItem:
125+
product = db.query(models.Product).filter(models.Product.id == item.product_id).first()
126+
if not product:
127+
raise ValueError("Product not found")
128+
129+
if product.stock is not None and item.quantity > product.stock:
130+
raise ValueError("Insufficient stock")
131+
121132
db_item = models.CartItem(cart_id=cart_id, **item.dict())
122133
db.add(db_item)
123134
db.commit()
@@ -134,10 +145,16 @@ def update_cart_item(db: Session, cart_item_id: int, quantity: int) -> models.Ca
134145

135146
def remove_cart_item(db: Session, cart_item_id: int) -> Optional[models.CartItem]:
136147
item = db.query(models.CartItem).filter(models.CartItem.id == cart_item_id).first()
137-
if item:
138-
db.delete(item)
139-
db.commit()
140-
return item
148+
# if item:
149+
# db.delete(item)
150+
# db.commit()
151+
# return item
152+
if not item:
153+
return False
154+
db.delete(item)
155+
db.commit()
156+
return True
157+
141158

142159
# ORDER CRUD
143160
def create_order(db: Session, order: schemas.OrderBase, user_id: int, items: List[schemas.OrderItemBase]) -> models.Order:
@@ -191,6 +208,16 @@ def create_order_from_cart_for_user(db: Session, user_id: int) -> models.Order:
191208
db.commit()
192209
db.refresh(order)
193210

211+
# Clear the cart items now that the order has been placed
212+
for item in cart_items:
213+
# If cart items have relationships, remove them safely
214+
try:
215+
db.delete(item)
216+
except Exception:
217+
# fallback: ignore deletion error and continue
218+
pass
219+
db.commit()
220+
194221
return order
195222

196223
# CATEGORY CRUD
@@ -236,8 +263,11 @@ def create_address(db: Session, user_id: int, address: schemas.AddressCreate):
236263
def get_addresses(db: Session, user_id: int):
237264
return db.query(models.Address).filter(models.Address.user_id == user_id).all()
238265

239-
def get_address(db: Session, user_id: int):
240-
return db.query(models.Address).filter(models.Address.id == user_id).first()
266+
def get_address(db: Session, address_id: int):
267+
return db.query(models.Address).filter(models.Address.id == address_id).first()
268+
269+
def get_address_by_user(db: Session, user_id: int):
270+
return db.query(models.Address).filter(models.Address.user_id == user_id).first()
241271

242272
def update_address(db: Session, db_address: models.Address, update: schemas.AddressUpdate):
243273
for key, value in update.dict(exclude_unset=True).items():
@@ -318,4 +348,4 @@ def update_shipment_status(db: Session, shipment_id: int, status: str):
318348
return shipment
319349

320350
def get_shipments(db: Session, skip: int = 0, limit: int = 100):
321-
return db.query(models.Shipment).offset(skip).limit(limit).all()
351+
return db.query(models.Shipment).offset(skip).limit(limit).all()

0 commit comments

Comments
 (0)