@@ -65,8 +65,12 @@ def get_product(db: Session, product_id: int) -> Optional[models.Product]:
6565 return db .query (models .Product ).filter (models .Product .id == product_id ).first ()
6666
6767def create_product (db : Session , product : schemas .ProductBase , seller_id : int = None ) -> models .Product :
68+ # Prevent client-provided seller_id from being trusted. Start from product dict
6869 product_data = product .dict ()
69- if seller_id :
70+ # Remove any seller_id supplied by client to avoid privilege escalation
71+ product_data .pop ("seller_id" , None )
72+ # If an explicit seller_id is provided by the server-side caller, use it
73+ if seller_id is not None :
7074 product_data ["seller_id" ] = seller_id
7175 db_product = models .Product (** product_data )
7276 db .add (db_product )
@@ -118,6 +122,13 @@ def create_cart(db: Session, user_id: int) -> models.Cart:
118122 return cart
119123
120124def add_cart_item (db : Session , cart_id : int , item : schemas .CartItemBase ) -> models .CartItem :
125+ product = db .query (models .Product ).filter (models .Product .id == item .product_id ).first ()
126+ if not product :
127+ raise ValueError ("Product not found" )
128+
129+ if product .stock is not None and item .quantity > product .stock :
130+ raise ValueError ("Insufficient stock" )
131+
121132 db_item = models .CartItem (cart_id = cart_id , ** item .dict ())
122133 db .add (db_item )
123134 db .commit ()
@@ -134,10 +145,16 @@ def update_cart_item(db: Session, cart_item_id: int, quantity: int) -> models.Ca
134145
135146def remove_cart_item (db : Session , cart_item_id : int ) -> Optional [models .CartItem ]:
136147 item = db .query (models .CartItem ).filter (models .CartItem .id == cart_item_id ).first ()
137- if item :
138- db .delete (item )
139- db .commit ()
140- return item
148+ # if item:
149+ # db.delete(item)
150+ # db.commit()
151+ # return item
152+ if not item :
153+ return False
154+ db .delete (item )
155+ db .commit ()
156+ return True
157+
141158
142159# ORDER CRUD
143160def create_order (db : Session , order : schemas .OrderBase , user_id : int , items : List [schemas .OrderItemBase ]) -> models .Order :
@@ -191,6 +208,16 @@ def create_order_from_cart_for_user(db: Session, user_id: int) -> models.Order:
191208 db .commit ()
192209 db .refresh (order )
193210
211+ # Clear the cart items now that the order has been placed
212+ for item in cart_items :
213+ # If cart items have relationships, remove them safely
214+ try :
215+ db .delete (item )
216+ except Exception :
217+ # fallback: ignore deletion error and continue
218+ pass
219+ db .commit ()
220+
194221 return order
195222
196223# CATEGORY CRUD
@@ -236,8 +263,11 @@ def create_address(db: Session, user_id: int, address: schemas.AddressCreate):
236263def get_addresses (db : Session , user_id : int ):
237264 return db .query (models .Address ).filter (models .Address .user_id == user_id ).all ()
238265
239- def get_address (db : Session , user_id : int ):
240- return db .query (models .Address ).filter (models .Address .id == user_id ).first ()
266+ def get_address (db : Session , address_id : int ):
267+ return db .query (models .Address ).filter (models .Address .id == address_id ).first ()
268+
269+ def get_address_by_user (db : Session , user_id : int ):
270+ return db .query (models .Address ).filter (models .Address .user_id == user_id ).first ()
241271
242272def update_address (db : Session , db_address : models .Address , update : schemas .AddressUpdate ):
243273 for key , value in update .dict (exclude_unset = True ).items ():
@@ -318,4 +348,4 @@ def update_shipment_status(db: Session, shipment_id: int, status: str):
318348 return shipment
319349
320350def get_shipments (db : Session , skip : int = 0 , limit : int = 100 ):
321- return db .query (models .Shipment ).offset (skip ).limit (limit ).all ()
351+ return db .query (models .Shipment ).offset (skip ).limit (limit ).all ()
0 commit comments