diff --git a/plutus-tx-plugin/changelog.d/20260616_plinth_deriving_plugin.md b/plutus-tx-plugin/changelog.d/20260616_plinth_deriving_plugin.md new file mode 100644 index 00000000000..8d527cd8b19 --- /dev/null +++ b/plutus-tx-plugin/changelog.d/20260616_plinth_deriving_plugin.md @@ -0,0 +1,8 @@ +### Added + +- The Plinth plugin now expands `deriving … via Plinth` clauses at parse time, + generating `AsData` pattern synonyms, `Optics` prisms, and `Match` functions + from data declarations. The pass is wired into `Plinth.Plugin`, so any module + compiled with the Plinth plugin gets it automatically — no extra `-fplugin`. + The implementation lives under `PlutusTx.Plugin.Deriving.*`, and the + deriving-via sentinel type is `Plinth` (`PlutusTx.Plugin.Deriving.Via`). diff --git a/plutus-tx-plugin/plutus-tx-plugin.cabal b/plutus-tx-plugin/plutus-tx-plugin.cabal index ab359f994da..b49a019b01a 100644 --- a/plutus-tx-plugin/plutus-tx-plugin.cabal +++ b/plutus-tx-plugin/plutus-tx-plugin.cabal @@ -56,11 +56,24 @@ library hs-source-dirs: src exposed-modules: Plinth.Plugin + PlutusTx.Plugin.Deriving.Via PlutusTx.Compiler.Error PlutusTx.Options PlutusTx.Plugin.Common other-modules: + Paths_plutus_tx_plugin + PlutusTx.Plugin.Deriving + PlutusTx.Plugin.Deriving.Constant.Module + PlutusTx.Plugin.Deriving.Generator.AsData + PlutusTx.Plugin.Deriving.Generator.Common + PlutusTx.Plugin.Deriving.Generator.Match + PlutusTx.Plugin.Deriving.Generator.Optics + PlutusTx.Plugin.Deriving.Hs + PlutusTx.Plugin.Deriving.Hsc + PlutusTx.Plugin.Deriving.Type.Constructor + PlutusTx.Plugin.Deriving.Type.Field + PlutusTx.Plugin.Deriving.Type.Type PlutusTx.Compiler.Binders PlutusTx.Compiler.Builtins PlutusTx.Compiler.Compat @@ -77,6 +90,8 @@ library PlutusTx.Plugin.Boilerplate PlutusTx.Plugin.Unsupported + autogen-modules: Paths_plutus_tx_plugin + build-depends: , array , base >=4.9 && <5 @@ -242,8 +257,10 @@ test-suite frontend-plugin-tests hs-source-dirs: test-frontend-plugin main-is: Spec.hs other-modules: + AsData.Spec Inlineable.Lib Inlineable.Spec + Match.Spec NoStrict.Spec Strict.Spec @@ -254,6 +271,7 @@ test-suite frontend-plugin-tests , plutus-tx-plugin ^>=1.65 , plutus-tx:plutus-tx-testlib , tasty + , tasty-hunit ghc-options: -threaded -rtsopts -with-rtsopts=-N diff --git a/plutus-tx-plugin/src/Plinth/Plugin.hs b/plutus-tx-plugin/src/Plinth/Plugin.hs index 528b90a15f7..1a1e8adf601 100644 --- a/plutus-tx-plugin/src/Plinth/Plugin.hs +++ b/plutus-tx-plugin/src/Plinth/Plugin.hs @@ -5,6 +5,7 @@ module Plinth.Plugin (plugin, plinthc) where import PlutusTx.Options import PlutusTx.Plugin.Boilerplate import PlutusTx.Plugin.Common +import PlutusTx.Plugin.Deriving qualified as Deriving import PlutusTx.Plugin.Unsupported import PlutusTx.Plugin.Utils @@ -18,6 +19,8 @@ plugin :: GHC.Plugin plugin = GHC.defaultPlugin { GHC.driverPlugin = addFlagsAndExts + , -- Expand @deriving … via Plinth@ clauses at parse time. + GHC.parsedResultAction = Deriving.parsedResultAction , GHC.typeCheckResultAction = \cliOpts _modSummary env -> do opts <- case parsePluginOptions (removeBoilerplateOpts cliOpts) of Success o -> pure o diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving.hs new file mode 100644 index 00000000000..315b448ea96 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving.hs @@ -0,0 +1,302 @@ +-- | The Plinth @deriving via@ pass. This is /not/ a standalone plugin: it is +-- wired into 'Plinth.Plugin.plugin' as its @parsedResultAction@, so that any +-- module compiled with the Plinth plugin can write +-- +-- > data Shape = Point | Circle Integer Integer +-- > deriving AsData via Plinth +-- > deriving Optics via Plinth +-- +-- without enabling a second plugin. +module PlutusTx.Plugin.Deriving + ( parsedResultAction, + ) +where + +import qualified Control.Monad as Monad +import qualified Control.Monad.IO.Class as IO +import qualified Data.Bifunctor as Bifunctor +import qualified Data.Maybe as Maybe +import qualified PlutusTx.Plugin.Deriving.Generator.AsData as AsData +import qualified PlutusTx.Plugin.Deriving.Generator.Match as Match +import qualified PlutusTx.Plugin.Deriving.Generator.Optics as Optics +import qualified PlutusTx.Plugin.Deriving.Generator.Common as Common +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +-- | The @parsedResultAction@ hook: rewrite @deriving … via Plinth@ clauses in +-- the freshly-parsed module into the generated declarations. +parsedResultAction :: + [Ghc.CommandLineOption] -> + Ghc.ModSummary -> + Ghc.ParsedResult -> + Ghc.Hsc Ghc.ParsedResult +parsedResultAction _commandLineOptions modSummary (Ghc.ParsedResult hsParsedModule msgs) = do + let moduleName = Ghc.moduleName $ Ghc.ms_mod modSummary + lHsModule2 <- handleLHsModule moduleName (Ghc.hpm_module hsParsedModule) + pure $ Ghc.ParsedResult hsParsedModule {Ghc.hpm_module = lHsModule2} msgs + +type LHsModule = Ghc.Located (Ghc.HsModule Ghc.GhcPs) + +handleLHsModule :: + Ghc.ModuleName -> + LHsModule -> + Ghc.Hsc LHsModule +handleLHsModule moduleName lHsModule = do + hsModule <- handleHsModule moduleName $ Ghc.unLoc lHsModule + pure $ Ghc.L (Ghc.getLoc lHsModule) hsModule + +handleHsModule :: + Ghc.ModuleName -> + Ghc.HsModule Ghc.GhcPs -> + Ghc.Hsc (Ghc.HsModule Ghc.GhcPs) +handleHsModule moduleName hsModule = do + (lImportDecls, lHsDecls) <- + handleLHsDecls moduleName $ + Ghc.hsmodDecls hsModule + pure + hsModule + { Ghc.hsmodImports = Ghc.hsmodImports hsModule <> lImportDecls, + Ghc.hsmodDecls = lHsDecls + } + +handleLHsDecls :: + Ghc.ModuleName -> + [Ghc.LHsDecl Ghc.GhcPs] -> + Ghc.Hsc ([Ghc.LImportDecl Ghc.GhcPs], [Ghc.LHsDecl Ghc.GhcPs]) +handleLHsDecls moduleName lHsDecls = do + tuples <- mapM (handleLHsDecl moduleName) lHsDecls + pure . Bifunctor.bimap mconcat mconcat $ unzip tuples + +handleLHsDecl :: + Ghc.ModuleName -> + Ghc.LHsDecl Ghc.GhcPs -> + Ghc.Hsc ([Ghc.LImportDecl Ghc.GhcPs], [Ghc.LHsDecl Ghc.GhcPs]) +handleLHsDecl moduleName lHsDecl = case Ghc.unLoc lHsDecl of + Ghc.TyClD xTyClD tyClDecl1 -> do + (mTyClDecl2, (lImportDecls, lHsDecls)) <- handleTyClDecl moduleName tyClDecl1 + case mTyClDecl2 of + Nothing -> + pure (lImportDecls, lHsDecls) + Just tyClDecl2 -> + let newDecl = Ghc.L (Ghc.getLoc lHsDecl) (Ghc.TyClD xTyClD tyClDecl2) + in pure (lImportDecls, newDecl : lHsDecls) + _ -> pure ([], [lHsDecl]) + +handleTyClDecl :: + Ghc.ModuleName -> + Ghc.TyClDecl Ghc.GhcPs -> + Ghc.Hsc + ( Maybe (Ghc.TyClDecl Ghc.GhcPs), + ([Ghc.LImportDecl Ghc.GhcPs], [Ghc.LHsDecl Ghc.GhcPs]) + ) +handleTyClDecl moduleName tyClDecl = case tyClDecl of + Ghc.DataDecl tcdDExt tcdLName tcdTyVars tcdFixity tcdDataDefn -> do + (mHsDataDefn, (lImportDecls, lHsDecls)) <- + handleHsDataDefn + moduleName + tcdLName + tcdTyVars + tcdDataDefn + pure + ( fmap (Ghc.DataDecl tcdDExt tcdLName tcdTyVars tcdFixity) mHsDataDefn, + (lImportDecls, lHsDecls) + ) + _ -> pure (Just tyClDecl, ([], [])) + +handleHsDataDefn :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + Ghc.HsDataDefn Ghc.GhcPs -> + Ghc.Hsc + ( Maybe (Ghc.HsDataDefn Ghc.GhcPs), + ([Ghc.LImportDecl Ghc.GhcPs], [Ghc.LHsDecl Ghc.GhcPs]) + ) +handleHsDataDefn moduleName lIdP lHsQTyVars hsDataDefn = + case hsDataDefn of + Ghc.HsDataDefn dd_ext dd_ctxt dd_cType dd_kindSig dd_cons dd_derivs -> + do + let consList = case dd_cons of + Ghc.DataTypeCons _ cs -> cs + Ghc.NewTypeCon c -> [c] + + (mHsDeriving, (lImportDecls, lHsDecls)) <- + handleHsDeriving + moduleName + lIdP + lHsQTyVars + consList + dd_derivs + + pure + ( fmap + (\hsDeriving -> Ghc.HsDataDefn dd_ext dd_ctxt dd_cType dd_kindSig dd_cons hsDeriving) + mHsDeriving, + (lImportDecls, lHsDecls) + ) + +handleHsDeriving :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.HsDeriving Ghc.GhcPs -> + Ghc.Hsc + ( Maybe (Ghc.HsDeriving Ghc.GhcPs), + ( [Ghc.LImportDecl Ghc.GhcPs], + [Ghc.LHsDecl Ghc.GhcPs] + ) + ) +handleHsDeriving moduleName lIdP lHsQTyVars lConDecls hsDeriving = do + (dropOriginal, lHsDerivingClauses, (lImportDecls, lHsDecls)) <- + handleLHsDerivingClauses moduleName lIdP lHsQTyVars lConDecls hsDeriving + pure + ( if dropOriginal then Nothing else Just lHsDerivingClauses, + (lImportDecls, lHsDecls) + ) + +handleLHsDerivingClauses :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.HsDeriving Ghc.GhcPs -> + Ghc.Hsc + ( Bool, + [Ghc.LHsDerivingClause Ghc.GhcPs], + ( [Ghc.LImportDecl Ghc.GhcPs], + [Ghc.LHsDecl Ghc.GhcPs] + ) + ) +handleLHsDerivingClauses moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses = + do + tuples <- + mapM + (handleLHsDerivingClause moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses) + lHsDerivingClauses + let (mClauses, dropFlags, extras) = unzip3 tuples + taggedExtras = zip dropFlags extras + orderedExtras = + fmap snd (filter fst taggedExtras) + <> fmap snd (filter (not . fst) taggedExtras) + pure + ( or dropFlags, + Maybe.catMaybes mClauses, + Bifunctor.bimap mconcat mconcat $ unzip orderedExtras + ) + +handleLHsDerivingClause :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.HsDeriving Ghc.GhcPs -> + Ghc.LHsDerivingClause Ghc.GhcPs -> + Ghc.Hsc + ( Maybe (Ghc.LHsDerivingClause Ghc.GhcPs), + Bool, + ( [Ghc.LImportDecl Ghc.GhcPs], + [Ghc.LHsDecl Ghc.GhcPs] + ) + ) +handleLHsDerivingClause moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses lHsDerivingClause = + case Ghc.unLoc lHsDerivingClause of + Ghc.HsDerivingClause _ deriv_clause_strategy deriv_clause_tys + | Common.isPlinthVia deriv_clause_strategy -> do + let nonPlinthClauses = filter + ( \c -> case Ghc.unLoc c of + Ghc.HsDerivingClause _ s _ -> + not (Common.isPlinthVia s) + ) + lHsDerivingClauses + (dropOriginal, lImportDecls, lHsDecls) <- + handleLHsSigTypes moduleName lIdP lHsQTyVars lConDecls nonPlinthClauses + . toLHsSigTypes + $ Ghc.unLoc deriv_clause_tys + pure (Nothing, dropOriginal, (lImportDecls, lHsDecls)) + _ -> pure (Just lHsDerivingClause, False, ([], [])) + +toLHsSigTypes :: Ghc.DerivClauseTys Ghc.GhcPs -> [Ghc.LHsSigType Ghc.GhcPs] +toLHsSigTypes derivClauseTys = case derivClauseTys of + Ghc.DctSingle _ lHsSigType -> [lHsSigType] + Ghc.DctMulti _ lHsSigTypes -> lHsSigTypes + +handleLHsSigTypes :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.HsDeriving Ghc.GhcPs -> + [Ghc.LHsSigType Ghc.GhcPs] -> + Ghc.Hsc + ( Bool, + [Ghc.LImportDecl Ghc.GhcPs], + [Ghc.LHsDecl Ghc.GhcPs] + ) +handleLHsSigTypes moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses lHsSigTypes = + do + tuples <- + mapM + (handleLHsSigType moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses) + lHsSigTypes + let (dropFlags, importLists, declLists) = unzip3 tuples + pure (or dropFlags, mconcat importLists, mconcat declLists) + +handleLHsSigType :: + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.HsDeriving Ghc.GhcPs -> + Ghc.LHsSigType Ghc.GhcPs -> + Ghc.Hsc + ( Bool, + [Ghc.LImportDecl Ghc.GhcPs], + [Ghc.LHsDecl Ghc.GhcPs] + ) +handleLHsSigType moduleName lIdP lHsQTyVars lConDecls lHsDerivingClauses lHsSigType = + do + let srcSpan = Ghc.getLocA lHsSigType + (dropOriginal, lImportDecls, lHsDecls) <- case getGenerator lHsSigType of + Just generate -> + generate lHsDerivingClauses moduleName lIdP lHsQTyVars lConDecls srcSpan + Nothing -> Hsc.throwError srcSpan $ Ghc.text "unsupported type class" + + verbose <- isVerbose + Monad.when verbose $ do + IO.liftIO $ do + putStrLn $ replicate 80 '-' + mapM_ (putStrLn . Ghc.showPprUnsafe . Ghc.ppr) lImportDecls + mapM_ (putStrLn . Ghc.showPprUnsafe . Ghc.ppr) lHsDecls + + pure (dropOriginal, lImportDecls, lHsDecls) + +-- | Whether to dump the generated declarations, driven by @-ddump-deriv@. +isVerbose :: Ghc.Hsc Bool +isVerbose = do + dynFlags <- Ghc.getDynFlags + pure $ Ghc.dopt Ghc.Opt_D_dump_deriv dynFlags + +getGenerator :: Ghc.LHsSigType Ghc.GhcPs -> Maybe (Ghc.HsDeriving Ghc.GhcPs -> Common.Generator) +getGenerator lHsSigType = do + className <- getClassName lHsSigType + lookup className generators + +generators :: [(String, Ghc.HsDeriving Ghc.GhcPs -> Common.Generator)] +generators = + [ ("AsData", AsData.generate), + ("Match", Match.generate), + ("Optics", Optics.generate) + ] + +getClassName :: Ghc.LHsSigType Ghc.GhcPs -> Maybe String +getClassName lHsSigType = do + lHsType <- case Ghc.unLoc lHsSigType of + Ghc.HsSig _ _ x -> Just x + lIdP <- case Ghc.unLoc lHsType of + Ghc.HsTyVar _ _ x -> Just x + _ -> Nothing + case Ghc.unLoc lIdP of + Ghc.Unqual x -> Just $ Ghc.occNameString x + _ -> Nothing diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Constant/Module.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Constant/Module.hs new file mode 100644 index 00000000000..a1b8b61eb77 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Constant/Module.hs @@ -0,0 +1,53 @@ +module PlutusTx.Plugin.Deriving.Constant.Module + ( controlApplicative, + controlLens, + dataAeson, + dataHashMapStrictInsOrd, + dataMaybe, + dataMonoid, + dataProxy, + dataString, + dataSwagger, + plutusTx, + plutusTxBuiltins, + testQuickCheck, + ) +where + +import qualified GHC.Unit.Module as Ghc + +controlApplicative :: Ghc.ModuleName +controlApplicative = Ghc.mkModuleName "Control.Applicative" + +controlLens :: Ghc.ModuleName +controlLens = Ghc.mkModuleName "Control.Lens" + +dataAeson :: Ghc.ModuleName +dataAeson = Ghc.mkModuleName "Data.Aeson" + +dataHashMapStrictInsOrd :: Ghc.ModuleName +dataHashMapStrictInsOrd = Ghc.mkModuleName "Data.HashMap.Strict.InsOrd" + +dataMaybe :: Ghc.ModuleName +dataMaybe = Ghc.mkModuleName "Data.Maybe" + +dataMonoid :: Ghc.ModuleName +dataMonoid = Ghc.mkModuleName "Data.Monoid" + +dataProxy :: Ghc.ModuleName +dataProxy = Ghc.mkModuleName "Data.Proxy" + +dataString :: Ghc.ModuleName +dataString = Ghc.mkModuleName "Data.String" + +dataSwagger :: Ghc.ModuleName +dataSwagger = Ghc.mkModuleName "Data.Swagger" + +plutusTx :: Ghc.ModuleName +plutusTx = Ghc.mkModuleName "PlutusTx" + +plutusTxBuiltins :: Ghc.ModuleName +plutusTxBuiltins = Ghc.mkModuleName "PlutusTx.Builtins" + +testQuickCheck :: Ghc.ModuleName +testQuickCheck = Ghc.mkModuleName "Test.QuickCheck" diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/AsData.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/AsData.hs new file mode 100644 index 00000000000..3d4e4d7c1d7 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/AsData.hs @@ -0,0 +1,504 @@ +{-# LANGUAGE CPP #-} +-- The 'head' calls below are on lists that are non-empty by construction +-- (guarded by the surrounding arity/constructor-count matches). +-- (-Wx-partial does not exist before GHC 9.8, hence -Wno-unrecognised-warning-flags) +{-# OPTIONS_GHC -Wno-unrecognised-warning-flags -Wno-x-partial #-} + +module PlutusTx.Plugin.Deriving.Generator.AsData + ( generate, + ) +where + +import qualified PlutusTx.Plugin.Deriving.Constant.Module as Module +import qualified PlutusTx.Plugin.Deriving.Generator.Common as Common +import qualified PlutusTx.Plugin.Deriving.Hs as Hs +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified PlutusTx.Plugin.Deriving.Type.Constructor as Constructor +import qualified PlutusTx.Plugin.Deriving.Type.Field as Field +import qualified PlutusTx.Plugin.Deriving.Type.Type as Type +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc +import qualified GHC.Types.Fixity as Ghc +import qualified GHC.Types.SourceText as Ghc + +-- | Replaces the original data declaration with a newtype backed by +-- 'BuiltinData', generates bidirectional pattern synonyms for each +-- constructor, and derives 'ToData'/'FromData' via GND. +-- +-- Given: +-- +-- > data Example a = Ex1 Integer | Ex2 a a +-- > deriving AsData via Plinth +-- +-- Generates: +-- +-- > newtype Example a = Example_BD PlutusTx.Builtins.BuiltinData +-- > deriving newtype (PlutusTx.ToData, PlutusTx.FromData) +-- > +-- > pattern Ex1 :: Integer -> Example a +-- > pattern Ex1 x0_ <- +-- > Example_BD ((\d_ -> PlutusTx.unsafeFromBuiltinData +-- > (PlutusTx.headBuiltinList (PlutusTx.sndPair (PlutusTx.unsafeDataAsConstr d_)))) -> x0_) +-- > where Ex1 x0_ = Example_BD (PlutusTx.mkConstr 0 [PlutusTx.toBuiltinData x0_]) +-- > +-- > pattern Ex2 :: a -> a -> Example a +-- > pattern Ex2 x0_ x1_ <- +-- > Example_BD ((\d_ -> let args_ = PlutusTx.sndPair (PlutusTx.unsafeDataAsConstr d_) +-- > in (PlutusTx.unsafeFromBuiltinData (PlutusTx.headBuiltinList args_), +-- > ...)) -> (x0_, x1_)) +-- > where Ex2 x0_ x1_ = Example_BD (PlutusTx.mkConstr 1 [...]) +-- > +-- > {-# COMPLETE Ex1, Ex2 #-} +generate :: Ghc.HsDeriving Ghc.GhcPs -> Common.Generator +generate remainingDerivs _moduleName lIdP lHsQTyVars lConDecls srcSpan = do + type_ <- Type.make lIdP lHsQTyVars lConDecls srcSpan + let constructors = Type.constructors type_ + when (null constructors) $ + Hsc.throwError srcSpan $ Ghc.text "AsData requires at least one constructor" + + plutusTx <- Common.makeRandomModule Module.plutusTx + plutusTxBuiltins <- Common.makeRandomModule Module.plutusTxBuiltins + + let lImportDecls = + Hs.importDecls + srcSpan + [ (Module.plutusTx, plutusTx), + (Module.plutusTxBuiltins, plutusTxBuiltins) + ] + + newtypeDecl = + makeNewtypeDecl srcSpan type_ plutusTx plutusTxBuiltins remainingDerivs + + completeDecl = + makeCompleteDecl srcSpan constructors + + patSynDecls <- + mapM + (\(idx, con) -> makePatSynDecl srcSpan type_ con idx plutusTx plutusTxBuiltins) + (zip [0 ..] constructors) + + pure (True, lImportDecls, newtypeDecl : concat patSynDecls <> [completeDecl]) + +when :: Applicative f => Bool -> f () -> f () +when True action = action +when False _ = pure () + +-- | The internal constructor name for the newtype. +internalConName :: Type.Type -> Ghc.OccName +internalConName type_ = + Ghc.mkDataOcc $ + Ghc.occNameString (Ghc.rdrNameOcc (Type.name type_)) <> "_BD" + +-- | Generate: @newtype Example a = Example_BD BuiltinData@ +-- @ deriving newtype (ToData, FromData)@ +makeNewtypeDecl :: + Ghc.SrcSpan -> + Type.Type -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.HsDeriving Ghc.GhcPs -> + Ghc.LHsDecl Ghc.GhcPs +makeNewtypeDecl srcSpan type_ plutusTx plutusTxBuiltins remainingDerivs = + let tyName = Ghc.rdrNameOcc $ Type.name type_ + lTypeName = Ghc.noLocA $ Ghc.mkRdrUnqual tyName + lConName = Ghc.noLocA $ Ghc.mkRdrUnqual (internalConName type_) + + builtinDataTy = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsTyVar + Ghc.noAnn + Ghc.NotPromoted + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.Qual plutusTxBuiltins (Ghc.mkTcOcc "BuiltinData")) + + conDecl = + Ghc.noLocA $ + Ghc.ConDeclH98 + { Ghc.con_ext = Ghc.noAnn, + Ghc.con_name = lConName, + Ghc.con_forall = False, + Ghc.con_ex_tvs = [], + Ghc.con_mb_cxt = Nothing, + Ghc.con_args = + Ghc.PrefixCon + [] + [Ghc.HsScaled Hs.unrestrictedArrow builtinDataTy], + Ghc.con_doc = Nothing + } + + -- deriving newtype (ToData, FromData) plus any remaining clauses + gndClause = makeGndClause srcSpan plutusTx + derivs = gndClause : remainingDerivs + + dataDefn = + Ghc.HsDataDefn +#if __GLASGOW_HASKELL__ >= 910 + { Ghc.dd_ext = Ghc.noAnn, +#else + { Ghc.dd_ext = Ghc.noExtField, +#endif + Ghc.dd_ctxt = Nothing, + Ghc.dd_cType = Nothing, + Ghc.dd_kindSig = Nothing, + Ghc.dd_cons = Ghc.NewTypeCon conDecl, + Ghc.dd_derivs = derivs + } + + -- Preserve type variables from the original type + tyVars = mkTyVars srcSpan (Type.variables type_) + + tyClDecl = + Ghc.DataDecl +#if __GLASGOW_HASKELL__ >= 910 + { Ghc.tcdDExt = Ghc.noExtField, +#else + { Ghc.tcdDExt = Ghc.noAnn, +#endif + Ghc.tcdLName = lTypeName, + Ghc.tcdTyVars = tyVars, + Ghc.tcdFixity = Ghc.Prefix, + Ghc.tcdDataDefn = dataDefn + } + in Ghc.noLocA (Ghc.TyClD Ghc.noExtField tyClDecl) + +-- | Build @deriving newtype (PlutusTx.ToData, PlutusTx.FromData)@. +makeGndClause :: + Ghc.SrcSpan -> + Ghc.ModuleName -> + Ghc.LHsDerivingClause Ghc.GhcPs +makeGndClause srcSpan plutusTx = + let strategy = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.NewtypeStrategy Ghc.noAnn + + mkCls occ = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsSig + Ghc.noExtField + Ghc.mkHsOuterImplicit + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsTyVar + Ghc.noAnn + Ghc.NotPromoted + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.Qual plutusTx occ) + ) + + tys = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.DctMulti Ghc.noExtField [mkCls (Ghc.mkClsOcc "ToData"), mkCls (Ghc.mkClsOcc "FromData"), mkCls (Ghc.mkClsOcc "UnsafeFromData")] + in Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsDerivingClause Ghc.noAnn (Just strategy) tys + +-- | Build @{-# COMPLETE Ex1, Ex2 #-}@. +makeCompleteDecl :: + Ghc.SrcSpan -> + [Constructor.Constructor] -> + Ghc.LHsDecl Ghc.GhcPs +makeCompleteDecl srcSpan constructors = + let conNames = + fmap + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) . Ghc.mkRdrUnqual . Ghc.rdrNameOcc . Constructor.name) + constructors + in Ghc.noLocA . Ghc.SigD Ghc.noExtField $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.CompleteMatchSig + ((Ghc.noAnn, Nothing, Ghc.noAnn), Ghc.NoSourceText) + conNames + Nothing +#else + Ghc.CompleteMatchSig + (Ghc.noAnn, Ghc.NoSourceText) + (Ghc.L srcSpan conNames) + Nothing +#endif + +-- | Generate the bidirectional pattern synonym for one constructor. +makePatSynDecl :: + Ghc.SrcSpan -> + Type.Type -> + Constructor.Constructor -> + Integer -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.Hsc [Ghc.LHsDecl Ghc.GhcPs] +makePatSynDecl srcSpan type_ constructor idx plutusTx plutusTxBuiltins = do + let fields = Constructor.fields constructor + arity = length fields + + vars <- mapM (\_ -> Common.makeRandomVariable srcSpan "x_") fields + dVar <- Common.makeRandomVariable srcSpan "d_" + tagVar <- Common.makeRandomVariable srcSpan "tag_" + argsVar <- Common.makeRandomVariable srcSpan "args_" + viewVars <- mapM (\_ -> Common.makeRandomVariable srcSpan "a_") fields + + let conRdrName = Constructor.name constructor + lConName = Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.mkRdrUnqual (Ghc.rdrNameOcc conRdrName) + internalCon = Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.mkRdrUnqual (internalConName type_) + + -- The "where" (builder) body: + -- Example_BD (mkConstr idx [toBuiltinData (x0_ :: T0), ...]) + encodeArgs = + fmap + ( \(v, field) -> + Hs.app + srcSpan + (Hs.qualVar srcSpan plutusTx (Ghc.mkVarOcc "toBuiltinData")) + -- type annotation so GHC can resolve ToData instance + (Hs.par srcSpan $ typeAnnotate srcSpan (Field.type_ field) (Hs.var srcSpan v)) + ) + (zip vars fields) + + builderBody = + Hs.app + srcSpan + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsVar Ghc.noExtField internalCon + ) + ( Hs.par srcSpan $ + Hs.app + srcSpan + ( Hs.app + srcSpan + (Hs.qualVar srcSpan plutusTxBuiltins (Ghc.mkVarOcc "mkConstr")) + (intLit srcSpan idx) + ) + (Hs.explicitList srcSpan encodeArgs) + ) + + -- The match (destructor) pattern uses a view pattern: + -- Example_BD (viewFn -> matchPat) + viewFn = makeViewFn srcSpan fields idx dVar tagVar argsVar viewVars plutusTx plutusTxBuiltins + matchPat = makeMatchPat srcSpan arity vars + + matchPat' = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ConPat + Ghc.noAnn + internalCon + ( Ghc.PrefixCon + [] + [ Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ViewPat Ghc.noAnn viewFn matchPat + ] + ) + + -- pattern synonym args + patArgs = Ghc.PrefixCon [] vars + + -- The explicit bidirectional direction + builderMatch = + Hs.funMatch + srcSpan + lConName + (fmap (Hs.varPat srcSpan) vars) + (Common.makeGRHSs srcSpan builderBody) + + patSynBind = + Ghc.PatSynBind Ghc.noExtField $ + Ghc.PSB + { Ghc.psb_ext = Ghc.noAnn, + Ghc.psb_id = lConName, + Ghc.psb_args = patArgs, + Ghc.psb_def = matchPat', + Ghc.psb_dir = Ghc.ExplicitBidirectional $ + Hs.mg (Ghc.L srcSpan [builderMatch]) + } + + patSynDecl = Ghc.noLocA $ Ghc.ValD Ghc.noExtField patSynBind + + -- The top-level signature: @pattern Con :: t0 -> ... -> tn -> T a b ...@ + tv n = Hs.tyVar srcSpan (Ghc.L (Ghc.noAnnSrcSpan srcSpan) n) + resultTy = + foldl + (\acc v -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.HsAppTy Ghc.noExtField acc (tv v))) + (tv (Type.name type_)) + (Type.variables type_) + patSynTy = + foldr + (\f acc -> Hs.funTy srcSpan (Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Field.type_ f)) acc) + resultTy + fields + sigDecl = + Ghc.noLocA . Ghc.SigD Ghc.noExtField $ + Ghc.PatSynSig Ghc.noAnn [lConName] $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsSig Ghc.noExtField Ghc.mkHsOuterImplicit patSynTy + + pure [sigDecl, patSynDecl] + +-- | Wrap an expression with a type annotation: @(expr :: ty)@. +typeAnnotate :: + Ghc.SrcSpan -> + Ghc.HsType Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +typeAnnotate srcSpan ty expr = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExprWithTySig + Ghc.noAnn + expr + ( Ghc.HsWC Ghc.noExtField $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsSig + Ghc.noExtField + Ghc.mkHsOuterImplicit + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) ty) + ) + +-- | Build the view function for deconstruction. +-- +-- Always checks the constructor tag; returns @Maybe@ so GHC can try the +-- next pattern alternative if the tag doesn't match: +-- +-- @\d_ -> let tag_ = fst (unsafeDataAsConstr d_) +-- args_ = snd (unsafeDataAsConstr d_) +-- in if tag_ == idx then Just \ else Nothing@ +-- +-- @\@ is @()@ for nullary constructors, @(x :: T)@ for arity-1, +-- or a tuple for higher arities. +makeViewFn :: + Ghc.SrcSpan -> + [Field.Field] -> + Integer -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + [Ghc.LIdP Ghc.GhcPs] -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.LHsExpr Ghc.GhcPs +makeViewFn srcSpan fields idx dVar tagVar argsVar viewVars plutusTx plutusTxBuiltins = + let ptx = Hs.qualVar srcSpan plutusTx + blt = Hs.qualVar srcSpan plutusTxBuiltins + arity = length fields + + -- fst / snd (unsafeDataAsConstr d_) + constrExpr = + Hs.app srcSpan + (blt (Ghc.mkVarOcc "unsafeDataAsConstr")) + (Hs.var srcSpan dVar) + + getFst = Hs.app srcSpan (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "fst"))) (Hs.par srcSpan constrExpr) + getSnd = Hs.app srcSpan (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "snd"))) (Hs.par srcSpan constrExpr) + + -- helper: 0-arg let binding var = rhs + mkLetFun var rhs = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.FunBind Ghc.noExtField var + (Hs.mg (Ghc.L srcSpan [Hs.funMatch srcSpan var [] (Common.makeGRHSs srcSpan rhs)])) + + tagBind = mkLetFun tagVar getFst + argsBind = mkLetFun argsVar getSnd + + -- (unsafeFromBuiltinData e) :: fieldType + decode fieldType e = + typeAnnotate srcSpan fieldType $ + Hs.app srcSpan + (ptx (Ghc.mkVarOcc "unsafeFromBuiltinData")) + (Hs.par srcSpan e) + + -- The decoded fields, sourced from the explicitly bound @viewVars@. + decoded = case zip viewVars fields of + [(v, f)] -> decode (Field.type_ f) (Hs.var srcSpan v) + vfs -> + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExplicitTuple + Ghc.noAnn + (fmap (\(v, f) -> Hs.tupArg (decode (Field.type_ f) (Hs.var srcSpan v))) vfs) + Ghc.Boxed + + justDecoded = + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Just"))) + (Hs.par srcSpan decoded) + + nothingExpr = + Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Nothing")) + + -- case args_ of { [a0, ...] -> Just ; _ -> Nothing } + argsListPat = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ListPat Ghc.noAnn (fmap (Hs.varPat srcSpan) viewVars) + wildPat = Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.WildPat Ghc.noExtField) + argsCase = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsCase Ghc.noAnn (Hs.var srcSpan argsVar) $ + Hs.mg $ + Ghc.L srcSpan + [ Hs.caseMatch srcSpan [argsListPat] (Common.makeGRHSs srcSpan justDecoded), + Hs.caseMatch srcSpan [wildPat] (Common.makeGRHSs srcSpan nothingExpr) + ] + + -- nullary: @Just ()@; otherwise the explicit-match case above + thenExpr = + if arity == 0 + then + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Just"))) + ( Hs.par srcSpan $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExplicitTuple Ghc.noAnn [] Ghc.Boxed + ) + else argsCase + + -- tagVar == idx + cond = + Hs.opApp srcSpan + (Hs.var srcSpan tagVar) + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "=="))) + (intLit srcSpan idx) + + ifExpr = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsIf Ghc.noAnn cond thenExpr nothingExpr + + -- omit argsBind for nullary constructors (avoid unused-variable warning) + letBinds = if arity == 0 then [tagBind] else [tagBind, argsBind] + + body = + Hs.letE srcSpan (Hs.valLocalBinds letBinds) ifExpr + in Hs.lam srcSpan . Hs.mg $ + Ghc.L srcSpan + [ Hs.match srcSpan + [Hs.varPat srcSpan dVar] + (Common.makeGRHSs srcSpan body) + ] + +-- | Build the match pattern for the view result. +-- Always wrapped in @Just@: nullary → @Just ()@, arity 1 → @Just x0_@, +-- arity n → @Just (x0_, x1_, ...)@ +makeMatchPat :: + Ghc.SrcSpan -> + Int -> + [Ghc.LIdP Ghc.GhcPs] -> + Ghc.LPat Ghc.GhcPs +makeMatchPat srcSpan arity vars = + let inner = case arity of + 0 -> + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.TuplePat Ghc.noAnn [] Ghc.Boxed + 1 -> + Hs.varPat srcSpan (head vars) + _ -> + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.TuplePat Ghc.noAnn (fmap (Hs.varPat srcSpan) vars) Ghc.Boxed + in Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ConPat Ghc.noAnn + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.mkRdrUnqual (Ghc.mkDataOcc "Just"))) + (Ghc.PrefixCon [] [inner]) + +-- | Rebuild 'LHsQTyVars' from the type variable 'RdrName's. +mkTyVars :: Ghc.SrcSpan -> [Ghc.IdP Ghc.GhcPs] -> Ghc.LHsQTyVars Ghc.GhcPs +mkTyVars srcSpan vars = + Ghc.HsQTvs Ghc.noExtField (fmap (Hs.userTyVar srcSpan) vars) + +-- | Integer overloaded literal. +intLit :: Ghc.SrcSpan -> Integer -> Ghc.LHsExpr Ghc.GhcPs +intLit s n = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsOverLit Ghc.noExtField $ +#else + Ghc.HsOverLit Ghc.noAnn $ +#endif + Ghc.OverLit Ghc.noExtField (Ghc.HsIntegral (Ghc.IL Ghc.NoSourceText False n)) + diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Common.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Common.hs new file mode 100644 index 00000000000..15e218913ec --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Common.hs @@ -0,0 +1,361 @@ +{-# LANGUAGE CPP #-} + +module PlutusTx.Plugin.Deriving.Generator.Common + ( Generator, + applyAll, + fieldNameOptions, + makeGRHSs, + makeInstanceDeclaration, + makeLHsBind, + makeRandomModule, + makeRandomVariable, + isPlinthVia, + ) +where + +import qualified Control.Monad.IO.Class as IO +import qualified Data.Char as Char +import qualified Data.IORef as IORef +import qualified Data.List as List +import qualified Data.Maybe as Maybe +import qualified Data.Text as Text +import qualified PlutusTx.Plugin.Deriving.Hs as Hs +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified PlutusTx.Plugin.Deriving.Type.Constructor as Constructor +import qualified PlutusTx.Plugin.Deriving.Type.Field as Field +import qualified PlutusTx.Plugin.Deriving.Type.Type as Type +#if __GLASGOW_HASKELL__ < 910 +import qualified GHC.Data.Bag as Ghc +#endif +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc +import qualified System.Console.GetOpt as Console +import qualified System.IO.Unsafe as Unsafe +import qualified Text.Printf as Printf + +-- | The 'Bool' indicates whether the original declaration should be dropped +-- (replaced by the generated declarations). Most generators return 'False'. +type Generator = + Ghc.ModuleName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.SrcSpan -> + Ghc.Hsc + (Bool, [Ghc.LImportDecl Ghc.GhcPs], [Ghc.LHsDecl Ghc.GhcPs]) + +fieldNameOptions :: + Ghc.SrcSpan -> [Console.OptDescr (String -> Ghc.Hsc String)] +fieldNameOptions srcSpan = + [ Console.Option [] ["kebab"] (Console.NoArg $ pure . kebab) "", + Console.Option [] ["camel"] (Console.NoArg $ pure . lower) "", + Console.Option [] ["snake"] (Console.NoArg $ pure . snake) "", + Console.Option [] ["prefix", "strip"] (Console.ReqArg (stripPrefix srcSpan) "PREFIX") "", + Console.Option [] ["suffix"] (Console.ReqArg (stripSuffix srcSpan) "SUFFIX") "", + Console.Option [] ["title"] (Console.NoArg $ pure . upper) "", + Console.Option [] ["rename"] (Console.ReqArg (rename srcSpan) "OLD:NEW") "" + ] + +stripPrefix :: Ghc.SrcSpan -> String -> String -> Ghc.Hsc String +stripPrefix srcSpan prefix s1 = case List.stripPrefix prefix s1 of + Nothing -> + Hsc.throwError srcSpan + . Ghc.text + $ show prefix + <> " is not a prefix of " + <> show s1 + Just s2 -> pure s2 + +stripSuffix :: Ghc.SrcSpan -> String -> String -> Ghc.Hsc String +stripSuffix srcSpan suffix s1 = case Text.stripSuffix (Text.pack suffix) (Text.pack s1) of + Nothing -> + Hsc.throwError srcSpan + . Ghc.text + $ show suffix + <> " is not a suffix of " + <> show s1 + Just s2 -> pure $ Text.unpack s2 + +rename :: Ghc.SrcSpan -> String -> String -> Ghc.Hsc String +rename loc arg str = + case Text.splitOn (Text.singleton ':') $ Text.pack arg of + [old, new] + | not (Text.null old || Text.null new) -> + pure $ + if Text.pack str == old + then Text.unpack new + else str + _ -> Hsc.throwError loc . Ghc.text $ show arg <> " is invalid" + +-- | Applies all the monadic functions in order beginning with some starting +-- value. +applyAll :: Monad m => [a -> m a] -> a -> m a +applyAll fs x = case fs of + [] -> pure x + f : gs -> do + y <- f x + applyAll gs y + +-- | Converts the first character into upper case. +upper :: String -> String +upper = overFirst Char.toUpper + +-- | Converts the first character into lower case. +lower :: String -> String +lower = overFirst Char.toLower + +overFirst :: (a -> a) -> [a] -> [a] +overFirst f xs = case xs of + x : ys -> f x : ys + _ -> xs + +-- | Converts the string into kebab case. +-- +-- >>> kebab "DoReMi" +-- "do-re-mi" +kebab :: String -> String +kebab = camelTo '-' + +-- | Converts the string into snake case. +-- +-- >>> snake "DoReMi" +-- "do_re_mi" +snake :: String -> String +snake = camelTo '_' + +camelTo :: Char -> String -> String +camelTo char = + let go wasUpper string = case string of + "" -> "" + first : rest -> + if Char.isUpper first + then + if wasUpper + then Char.toLower first : go True rest + else char : Char.toLower first : go True rest + else first : go False rest + in go True + +makeLHsType :: + Ghc.SrcSpan -> + Ghc.ModuleName -> + Ghc.OccName -> + Type.Type -> + Ghc.LHsType Ghc.GhcPs +makeLHsType srcSpan moduleName className = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.HsAppTy + Ghc.noExtField + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted + . Ghc.L (Ghc.noAnnSrcSpan srcSpan) + $ Ghc.Qual moduleName className + ) + . toLHsType srcSpan + +toLHsType :: Ghc.SrcSpan -> Type.Type -> Ghc.LHsType Ghc.GhcPs +toLHsType srcSpan type_ = + let ext :: Ghc.NoExtField + ext = Ghc.noExtField + + -- A bare type variable, used as a type. Each location wrapper is a + -- fresh expression so its annotation type is inferred per-position + -- (a shared @loc@ binding monomorphises to the wrong annotation). + tv :: Ghc.IdP Ghc.GhcPs -> Ghc.LHsType Ghc.GhcPs + tv n = Hs.tyVar srcSpan (Ghc.L (Ghc.noAnnSrcSpan srcSpan) n) + + initial :: Ghc.LHsType Ghc.GhcPs + initial = tv (Type.name type_) + + combine :: + Ghc.LHsType Ghc.GhcPs -> Ghc.IdP Ghc.GhcPs -> Ghc.LHsType Ghc.GhcPs + combine x v = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.HsAppTy ext x (tv v)) + + bare :: Ghc.LHsType Ghc.GhcPs + bare = List.foldl' combine initial $ Type.variables type_ + in case Type.variables type_ of + [] -> bare + _ -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.HsParTy Ghc.noAnn bare + +makeHsContext :: + Ghc.SrcSpan -> + Ghc.ModuleName -> + Ghc.OccName -> + Type.Type -> + [Ghc.LHsType Ghc.GhcPs] +makeHsContext srcSpan moduleName className = + fmap + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.HsAppTy + Ghc.noExtField + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted + . Ghc.L (Ghc.noAnnSrcSpan srcSpan) + $ Ghc.Qual moduleName className + ) + . Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted + . Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.Unqual + ) + . List.nub + . Maybe.mapMaybe + ( \field -> case Field.type_ field of + Ghc.HsTyVar _ _ lRdrName -> case Ghc.unLoc lRdrName of + Ghc.Unqual occName | Ghc.isTvOcc occName -> Just occName + _ -> Nothing + _ -> Nothing + ) + . concatMap Constructor.fields + . Type.constructors + +makeHsImplicitBndrs :: + Ghc.SrcSpan -> + Type.Type -> + Ghc.ModuleName -> + Ghc.OccName -> + Ghc.LHsSigType Ghc.GhcPs +makeHsImplicitBndrs srcSpan type_ moduleName className = + let withoutContext = makeLHsType srcSpan moduleName className type_ + context = makeHsContext srcSpan moduleName className type_ + withContext = + if null context + then withoutContext + else + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsQualTy Ghc.noExtField (Ghc.L (Ghc.noAnnSrcSpan srcSpan) context) withoutContext + in Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.HsSig Ghc.noExtField Ghc.mkHsOuterImplicit withContext + +-- | Makes a random variable name using the given prefix. +makeRandomVariable :: Ghc.SrcSpan -> String -> Ghc.Hsc (Ghc.LIdP Ghc.GhcPs) +makeRandomVariable srcSpan prefix = do + n <- bumpCounter + pure . Ghc.L (Ghc.noAnnSrcSpan srcSpan) . Ghc.Unqual . Ghc.mkVarOcc $ + Printf.printf + "%s%d" + prefix + n + +-- | Makes a random module name. This will convert any periods to underscores +-- and add a unique suffix. +-- +-- >>> makeRandomModule "Data.Aeson" +-- "Data_Aeson_1" +makeRandomModule :: Ghc.ModuleName -> Ghc.Hsc Ghc.ModuleName +makeRandomModule moduleName = do + n <- bumpCounter + pure . Ghc.mkModuleName $ + Printf.printf + "%s_%d" + (underscoreAll moduleName) + n + +underscoreAll :: Ghc.ModuleName -> String +underscoreAll = fmap underscoreOne . Ghc.moduleNameString + +underscoreOne :: Char -> Char +underscoreOne c = case c of + '.' -> '_' + _ -> c + +makeInstanceDeclaration :: + Ghc.SrcSpan -> + Type.Type -> + Ghc.ModuleName -> + Ghc.OccName -> + [Ghc.LHsBind Ghc.GhcPs] -> + Ghc.LHsDecl Ghc.GhcPs +makeInstanceDeclaration srcSpan type_ moduleName occName lHsBinds = + let hsImplicitBndrs = makeHsImplicitBndrs srcSpan type_ moduleName occName + in makeLHsDecl srcSpan hsImplicitBndrs lHsBinds + +makeLHsDecl :: + Ghc.SrcSpan -> + Ghc.LHsSigType Ghc.GhcPs -> + [Ghc.LHsBind Ghc.GhcPs] -> + Ghc.LHsDecl Ghc.GhcPs +makeLHsDecl srcSpan hsImplicitBndrs lHsBinds = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) + . Ghc.InstD Ghc.noExtField + . Ghc.ClsInstD Ghc.noExtField + $ Ghc.ClsInstDecl +#if __GLASGOW_HASKELL__ >= 910 + (Nothing, Ghc.noAnn, Ghc.NoAnnSortKey) + hsImplicitBndrs + lHsBinds +#else + (Ghc.noAnn, Ghc.NoAnnSortKey) + hsImplicitBndrs + (Ghc.listToBag lHsBinds) +#endif + [] [] [] Nothing + +makeLHsBind :: + Ghc.SrcSpan -> + Ghc.OccName -> + [Ghc.LPat Ghc.GhcPs] -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsBind Ghc.GhcPs +makeLHsBind srcSpan occName pats = + Hs.funBind srcSpan occName . makeMatchGroup srcSpan occName pats + +makeMatchGroup :: + Ghc.SrcSpan -> + Ghc.OccName -> + [Ghc.LPat Ghc.GhcPs] -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.MatchGroup Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +makeMatchGroup srcSpan occName lPats hsExpr = + Hs.mg + ( Ghc.L srcSpan + [ Hs.funMatch + srcSpan + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.Unqual occName)) + lPats + (makeGRHSs srcSpan hsExpr) + ] + ) + +makeGRHSs :: + Ghc.SrcSpan -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.GRHSs Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +makeGRHSs srcSpan hsExpr = + Ghc.GRHSs Ghc.emptyComments [Hs.grhs srcSpan hsExpr] $ + Ghc.EmptyLocalBinds Ghc.noExtField + +bumpCounter :: IO.MonadIO m => m Word +bumpCounter = IO.liftIO . IORef.atomicModifyIORef' counterRef $ \n -> (n + 1, n) + +counterRef :: IORef.IORef Word +counterRef = Unsafe.unsafePerformIO $ IORef.newIORef 0 +{-# NOINLINE counterRef #-} + + + +-- | This plugin only fires on specific deriving strategies. In particular it +-- looks for clauses like this: +-- +-- > deriving C via Plinth +-- +-- where @Plinth@ is the sentinel type @data Plinth@. Using a real type (rather +-- than a @Symbol@ string literal) means the name must be in scope, so when the +-- plugin is not loaded GHC reports a clean error instead of a confusing one. +-- +-- This function is responsible for analyzing a deriving strategy to determine +-- if the plugin should fire or not. +isPlinthVia :: + Maybe (Ghc.LDerivStrategy Ghc.GhcPs) -> Bool +isPlinthVia mLDerivStrategy = Maybe.fromMaybe False $ do + lDerivStrategy <- mLDerivStrategy + lHsSigType <- case Ghc.unLoc lDerivStrategy of + Ghc.ViaStrategy (Ghc.XViaStrategyPs _ x) -> Just $ Ghc.unLoc x + _ -> Nothing + lHsType <- case lHsSigType of + Ghc.HsSig _ _ x -> Just x + rdrName <- case Ghc.unLoc lHsType of + Ghc.HsTyVar _ _ x -> Just $ Ghc.unLoc x + _ -> Nothing + pure $ Ghc.occNameString (Ghc.rdrNameOcc rdrName) == "Plinth" + diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Match.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Match.hs new file mode 100644 index 00000000000..0a709eb81a0 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Match.hs @@ -0,0 +1,327 @@ +-- The 'head' call below is on a list that is non-empty by construction. +-- (-Wx-partial does not exist before GHC 9.8, hence -Wno-unrecognised-warning-flags) +{-# OPTIONS_GHC -Wno-unrecognised-warning-flags -Wno-x-partial #-} + +module PlutusTx.Plugin.Deriving.Generator.Match + ( generate, + ) +where + +import qualified Data.List as List +import qualified PlutusTx.Plugin.Deriving.Constant.Module as Module +import qualified PlutusTx.Plugin.Deriving.Generator.Common as Common +import qualified PlutusTx.Plugin.Deriving.Hs as Hs +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified PlutusTx.Plugin.Deriving.Type.Constructor as Constructor +import qualified PlutusTx.Plugin.Deriving.Type.Field as Field +import qualified PlutusTx.Plugin.Deriving.Type.Type as Type +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +-- | Generates a CPS-style destructor function for 'AsData' sum types. +-- +-- Given: +-- +-- > data Example a = Ex1 Integer | Ex2 a a +-- > deriving (AsData, Match) via Plinth +-- +-- Generates: +-- +-- > matchExample :: Example a -> (Integer -> r_N) -> (a -> a -> r_N) -> r_N +-- > matchExample (Example_BD d_) f_0 f_1 = +-- > let tag_ = fst (PlutusTx.Builtins.unsafeDataAsConstr d_) +-- > args_ = snd (PlutusTx.Builtins.unsafeDataAsConstr d_) +-- > in if tag_ == 0 +-- > then f_0 ((PlutusTx.unsafeFromBuiltinData (head args_)) :: Integer) +-- > else f_1 ((PlutusTx.unsafeFromBuiltinData (head args_)) :: a) +-- > ((PlutusTx.unsafeFromBuiltinData (head (tail args_))) :: a) +-- +-- For a single-constructor type, the tag check is omitted entirely: +-- +-- > data Address = Address Credential (Maybe StakingCredential) +-- > deriving (AsData, Match) via Plinth +-- +-- Generates: +-- +-- > matchAddress :: Address -> (Credential -> Maybe StakingCredential -> r_N) -> r_N +-- > matchAddress (Address_BD d_) f_ = +-- > let args_ = snd (PlutusTx.Builtins.unsafeDataAsConstr d_) +-- > in f_ ((PlutusTx.unsafeFromBuiltinData (head args_)) :: Credential) +-- > ((PlutusTx.unsafeFromBuiltinData (head (tail args_))) :: Maybe StakingCredential) +generate :: Ghc.HsDeriving Ghc.GhcPs -> Common.Generator +generate _ _moduleName lIdP lHsQTyVars lConDecls srcSpan = do + type_ <- Type.make lIdP lHsQTyVars lConDecls srcSpan + let constructors = Type.constructors type_ + when (null constructors) $ + Hsc.throwError srcSpan $ Ghc.text "Match requires at least one constructor" + + plutusTx <- Common.makeRandomModule Module.plutusTx + plutusTxBuiltins <- Common.makeRandomModule Module.plutusTxBuiltins + + let lImportDecls = + Hs.importDecls + srcSpan + [ (Module.plutusTx, plutusTx), + (Module.plutusTxBuiltins, plutusTxBuiltins) + ] + + decls <- makeMatchDecls srcSpan type_ constructors plutusTx plutusTxBuiltins + pure (False, lImportDecls, decls) + +when :: Applicative f => Bool -> f () -> f () +when True action = action +when False _ = pure () + +-- | The internal BD constructor name (same convention as 'AsData'). +internalConName :: Type.Type -> Ghc.OccName +internalConName type_ = + Ghc.mkDataOcc $ + Ghc.occNameString (Ghc.rdrNameOcc (Type.name type_)) <> "_BD" + +-- | @"match" <> TypeName@, e.g. @matchExample@. +matchFunOcc :: Type.Type -> Ghc.OccName +matchFunOcc type_ = + Ghc.mkVarOcc $ + "match" <> Ghc.occNameString (Ghc.rdrNameOcc (Type.name type_)) + +makeMatchDecls :: + Ghc.SrcSpan -> + Type.Type -> + [Constructor.Constructor] -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.Hsc [Ghc.LHsDecl Ghc.GhcPs] +makeMatchDecls srcSpan type_ constructors plutusTx plutusTxBuiltins = do + let funOcc = matchFunOcc type_ + funId = Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.Unqual funOcc) + internalCon = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.mkRdrUnqual (internalConName type_) + + dVar <- Common.makeRandomVariable srcSpan "d_" + tagVar <- Common.makeRandomVariable srcSpan "tag_" + argsVar <- Common.makeRandomVariable srcSpan "args_" + contVars <- mapM (\_ -> Common.makeRandomVariable srcSpan "f_") constructors + fieldVarss <- + mapM + (mapM (\_ -> Common.makeRandomVariable srcSpan "a_") . Constructor.fields) + constructors + rVar <- Common.makeRandomVariable srcSpan "r_" + + let sigDecl = makeSigDecl srcSpan type_ constructors funId rVar + valDecl = + makeValDecl + srcSpan + constructors + funOcc + dVar + tagVar + argsVar + internalCon + contVars + fieldVarss + plutusTx + plutusTxBuiltins + + pure [sigDecl, valDecl] + +-- | Build the type signature. +-- +-- @matchExample :: Example a -> (Integer -> r_N) -> (a -> a -> r_N) -> r_N@ +makeSigDecl :: + Ghc.SrcSpan -> + Type.Type -> + [Constructor.Constructor] -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsDecl Ghc.GhcPs +makeSigDecl srcSpan type_ constructors funId rVar = + let loc = Ghc.noAnnSrcSpan srcSpan + + -- @rVar@ is made in the value namespace; as the result /type/ variable + -- it must be in the type-variable namespace, else implicit + -- quantification does not bind it ("not in scope"). + rTyName = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.mkRdrUnqual (Ghc.mkTyVarOcc (Ghc.occNameString (Ghc.rdrNameOcc (Ghc.unLoc rVar)))) + rTy = Ghc.L loc $ Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted rTyName + + -- A -> B -> ... -> r for a constructor's fields + mkContTy fields = + foldr + (\field acc -> Hs.funTy srcSpan (Ghc.L loc (Field.type_ field)) acc) + rTy + fields + + -- Wrap in parens unless nullary (just r) + mkContTyPar fields = case fields of + [] -> rTy + _ -> Ghc.L loc $ Ghc.HsParTy Ghc.noAnn (mkContTy fields) + + outerTy = mkOuterTy srcSpan type_ + contTys = fmap (mkContTyPar . Constructor.fields) constructors + + -- TypeName vars -> cont0 -> ... -> r + fullTy = + foldr + (\argTy acc -> Hs.funTy srcSpan argTy acc) + rTy + (outerTy : contTys) + in Ghc.noLocA $ Ghc.SigD Ghc.noExtField $ + Ghc.TypeSig Ghc.noAnn [funId] $ + Ghc.HsWC Ghc.noExtField $ + Ghc.L loc $ + Ghc.HsSig Ghc.noExtField Ghc.mkHsOuterImplicit fullTy + +-- | @TypeName a b ...@ as an 'LHsType', parenthesised when there are type vars. +mkOuterTy :: Ghc.SrcSpan -> Type.Type -> Ghc.LHsType Ghc.GhcPs +mkOuterTy srcSpan type_ = + let -- Fresh location wrappers per position (a shared @loc@ monomorphises + -- to the wrong annotation type under GHC ≥ 9.10). + tv n = Hs.tyVar srcSpan (Ghc.L (Ghc.noAnnSrcSpan srcSpan) n) + initial = tv (Type.name type_) + applied = + List.foldl' + ( \acc v -> + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsAppTy Ghc.noExtField acc (tv v) + ) + initial + (Type.variables type_) + in case Type.variables type_ of + [] -> applied + _ -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.HsParTy Ghc.noAnn applied + +-- | Build the function value declaration. +makeValDecl :: + Ghc.SrcSpan -> + [Constructor.Constructor] -> + Ghc.OccName -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.LIdP Ghc.GhcPs -> + [Ghc.LIdP Ghc.GhcPs] -> + [[Ghc.LIdP Ghc.GhcPs]] -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.LHsDecl Ghc.GhcPs +makeValDecl srcSpan constructors funOcc dVar tagVar argsVar internalCon contVars fieldVarss plutusTx plutusTxBuiltins = + let ptx = Hs.qualVar srcSpan plutusTx + blt = Hs.qualVar srcSpan plutusTxBuiltins + + -- (unsafeFromBuiltinData e) :: FieldType + decode field e = + typeAnnotate srcSpan (Field.type_ field) $ + Hs.app srcSpan + (ptx (Ghc.mkVarOcc "unsafeFromBuiltinData")) + (Hs.par srcSpan e) + + unitExpr = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.ExplicitTuple Ghc.noAnn [] Ghc.Boxed + + -- Apply a continuation to the decoded fields, binding them with an + -- explicit list pattern rather than head/tail: + -- @case args_ of { [a0, ...] -> f_ (decode a0) ...; _ -> error () }@. + -- The wildcard branch is unreachable for well-formed Data. + applyFn fVar fields fieldVars = case fields of + [] -> Hs.var srcSpan fVar + _ -> + let applied = + List.foldl' + (Hs.app srcSpan) + (Hs.var srcSpan fVar) + (zipWith (\f v -> decode f (Hs.var srcSpan v)) fields fieldVars) + listPat = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ListPat Ghc.noAnn (fmap (Hs.varPat srcSpan) fieldVars) + wildPat = Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.WildPat Ghc.noExtField) + errExpr = + Hs.app srcSpan (blt (Ghc.mkVarOcc "error")) (Hs.par srcSpan unitExpr) + in Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsCase Ghc.noAnn (Hs.var srcSpan argsVar) $ + Hs.mg $ + Ghc.L srcSpan + [ Hs.caseMatch srcSpan [listPat] (Common.makeGRHSs srcSpan applied), + Hs.caseMatch srcSpan [wildPat] (Common.makeGRHSs srcSpan errExpr) + ] + + -- Nested if-else dispatch; last constructor falls through without a tag check + makeDispatch [] = error "Match.makeDispatch: empty list" + makeDispatch [(_, fVar, con, fvs)] = applyFn fVar (Constructor.fields con) fvs + makeDispatch ((idx, fVar, con, fvs) : rest) = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsIf Ghc.noAnn + ( Hs.opApp srcSpan + (Hs.var srcSpan tagVar) + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "=="))) + (intLit srcSpan idx) + ) + (applyFn fVar (Constructor.fields con) fvs) + (makeDispatch rest) + + needsTag = length constructors > 1 + needsArgs = any (not . null . Constructor.fields) constructors + + constrExpr = + Hs.app srcSpan + (blt (Ghc.mkVarOcc "unsafeDataAsConstr")) + (Hs.var srcSpan dVar) + + getFst = Hs.app srcSpan (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "fst"))) (Hs.par srcSpan constrExpr) + getSnd = Hs.app srcSpan (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "snd"))) (Hs.par srcSpan constrExpr) + + mkLetFun var rhs = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.FunBind Ghc.noExtField var + (Hs.mg (Ghc.L srcSpan [Hs.funMatch srcSpan var [] (Common.makeGRHSs srcSpan rhs)])) + + tagBind = mkLetFun tagVar getFst + argsBind = mkLetFun argsVar getSnd + + letBinds = + (if needsTag then [tagBind] else []) + <> (if needsArgs then [argsBind] else []) + + innerBody = case constructors of + [con] -> applyFn (head contVars) (Constructor.fields con) (head fieldVarss) + _ -> makeDispatch (List.zip4 [0 ..] contVars constructors fieldVarss) + + body = + if null letBinds + then innerBody + else + Hs.letE srcSpan (Hs.valLocalBinds letBinds) innerBody + + -- (TypeName_BD d_) or (TypeName_BD _) when d_ is unused + innerPat = + if null letBinds + then Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.WildPat Ghc.noExtField + else Hs.varPat srcSpan dVar + + dPat = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ConPat Ghc.noAnn internalCon (Ghc.PrefixCon [] [innerPat]) + + allPats = dPat : fmap (Hs.varPat srcSpan) contVars + + in Ghc.noLocA $ Ghc.ValD Ghc.noExtField $ + Ghc.unLoc (Common.makeLHsBind srcSpan funOcc allPats body) + +-- | Wrap an expression with a type annotation: @(expr :: ty)@. +typeAnnotate :: + Ghc.SrcSpan -> + Ghc.HsType Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +typeAnnotate srcSpan ty expr = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExprWithTySig Ghc.noAnn expr $ + Ghc.HsWC Ghc.noExtField $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsSig Ghc.noExtField Ghc.mkHsOuterImplicit + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) ty) + +-- | Integer overloaded literal. +intLit :: Ghc.SrcSpan -> Integer -> Ghc.LHsExpr Ghc.GhcPs +intLit = Hs.intLit diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Optics.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Optics.hs new file mode 100644 index 00000000000..6222424eba1 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Generator/Optics.hs @@ -0,0 +1,216 @@ +-- The 'head' calls below are on lists that are non-empty by construction +-- (guarded by the surrounding arity matches). +-- (-Wx-partial does not exist before GHC 9.8, hence -Wno-unrecognised-warning-flags) +{-# OPTIONS_GHC -Wno-unrecognised-warning-flags -Wno-x-partial #-} + +module PlutusTx.Plugin.Deriving.Generator.Optics + ( generate, + ) +where + +import qualified Data.List as List +import qualified PlutusTx.Plugin.Deriving.Constant.Module as Module +import qualified PlutusTx.Plugin.Deriving.Generator.Common as Common +import qualified PlutusTx.Plugin.Deriving.Hs as Hs +import qualified PlutusTx.Plugin.Deriving.Type.Constructor as Constructor +import qualified PlutusTx.Plugin.Deriving.Type.Field as Field +import qualified PlutusTx.Plugin.Deriving.Type.Type as Type +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +-- | For each constructor generates a 'Prism'' binding. For example: +-- +-- > data Shape = Point | Circle Integer | Rectangle Integer Integer +-- > deriving Optics via Plinth +-- +-- Produces: +-- +-- > _Point :: Prism' Shape () +-- > _Circle :: Prism' Shape Integer +-- > _Rectangle :: Prism' Shape (Integer, Integer) +generate :: Ghc.HsDeriving Ghc.GhcPs -> Common.Generator +generate _ _moduleName lIdP lHsQTyVars lConDecls srcSpan = do + type_ <- Type.make lIdP lHsQTyVars lConDecls srcSpan + lens <- Common.makeRandomModule Module.controlLens + let lImportDecls = Hs.importDecls srcSpan [(Module.controlLens, lens)] + decls <- mapM (makePrismDecls srcSpan type_ lens) (Type.constructors type_) + pure (False, lImportDecls, concat decls) + +-- | Generates the signature and binding for one prism. +makePrismDecls :: + Ghc.SrcSpan -> + Type.Type -> + Ghc.ModuleName -> + Constructor.Constructor -> + Ghc.Hsc [Ghc.LHsDecl Ghc.GhcPs] +makePrismDecls srcSpan type_ lens constructor = do + let fields = Constructor.fields constructor + arity = length fields + conOcc = Ghc.rdrNameOcc (Constructor.name constructor) + prismOcc = Ghc.mkVarOcc ("_" <> Ghc.occNameString conOcc) + prismId = Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.Unqual prismOcc) + lConId = Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Ghc.mkRdrUnqual conOcc) + + vars <- mapM (\_ -> Common.makeRandomVariable srcSpan "x_") fields + scrutVar <- Common.makeRandomVariable srcSpan "x_" + + let -- Focus type: () / T / (T0, T1, ...) + fieldTys = fmap (\f -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) (Field.type_ f)) fields + focusTy = case arity of + 0 -> unitTy srcSpan + 1 -> head fieldTys + _ -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.HsTupleTy Ghc.noAnn Ghc.HsBoxedOrConstraintTuple fieldTys + + -- Outer type: TypeName a b ... + outerTy = mkOuterTy srcSpan type_ + + -- Prism' OuterTy FocusTy + prismTy = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsAppTy Ghc.noExtField + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsAppTy Ghc.noExtField + ( Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted + (Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.Qual lens (Ghc.mkTcOcc "Prism'")) + ) + outerTy + ) + focusTy + + -- _ConName :: Prism' OuterTy FocusTy + sigDecl = + Ghc.noLocA $ Ghc.SigD Ghc.noExtField $ + Ghc.TypeSig Ghc.noAnn [prismId] $ + Ghc.HsWC Ghc.noExtField $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsSig Ghc.noExtField Ghc.mkHsOuterImplicit prismTy + + -- _ConName = lens.prism' builder matcher + prismExpr = + Hs.app srcSpan + ( Hs.app srcSpan + (Hs.qualVar srcSpan lens (Ghc.mkVarOcc "prism'")) + (Hs.par srcSpan $ mkBuilder srcSpan lConId vars arity) + ) + (Hs.par srcSpan $ mkMatcher srcSpan lConId vars scrutVar arity) + + valDecl = + Ghc.noLocA $ Ghc.ValD Ghc.noExtField $ + Ghc.FunBind Ghc.noExtField prismId + (Hs.mg (Ghc.L srcSpan [Hs.funMatch srcSpan prismId [] (Common.makeGRHSs srcSpan prismExpr)])) + + pure [sigDecl, valDecl] + +-- | Builder function: converts the focus type back to the sum type. +-- +-- * Nullary: @const ConName@ +-- * Unary: @ConName@ +-- * Multi: @\(x0, x1, ...) -> ConName x0 x1 ...@ +mkBuilder :: + Ghc.SrcSpan -> + Ghc.LIdP Ghc.GhcPs -> + [Ghc.LIdP Ghc.GhcPs] -> + Int -> + Ghc.LHsExpr Ghc.GhcPs +mkBuilder srcSpan lConId vars arity = case arity of + 0 -> + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkVarOcc "const"))) + (Hs.var srcSpan lConId) + 1 -> + Hs.var srcSpan lConId + _ -> + let tuplePat = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.TuplePat Ghc.noAnn (fmap (Hs.varPat srcSpan) vars) Ghc.Boxed + body = List.foldl' (Hs.app srcSpan) (Hs.var srcSpan lConId) (fmap (Hs.var srcSpan) vars) + in Hs.lam srcSpan . Hs.mg $ + Ghc.L srcSpan + [Hs.match srcSpan [tuplePat] (Common.makeGRHSs srcSpan body)] + +-- | Matcher function: converts the sum type to @Maybe@ focus. +-- +-- @\x -> case x of { ConName x0 x1 ... -> Just (x0, x1, ...); _ -> Nothing }@ +mkMatcher :: + Ghc.SrcSpan -> + Ghc.LIdP Ghc.GhcPs -> + [Ghc.LIdP Ghc.GhcPs] -> + Ghc.LIdP Ghc.GhcPs -> + Int -> + Ghc.LHsExpr Ghc.GhcPs +mkMatcher srcSpan lConId vars scrutVar arity = + let successResult = case arity of + 0 -> + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Just"))) + (unitExpr srcSpan) + 1 -> + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Just"))) + (Hs.par srcSpan $ Hs.var srcSpan (head vars)) + _ -> + Hs.app srcSpan + (Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Just"))) + ( Hs.par srcSpan $ + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExplicitTuple Ghc.noAnn (fmap (Hs.tupArg . Hs.var srcSpan) vars) Ghc.Boxed + ) + + conPat = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ConPat Ghc.noAnn lConId (Ghc.PrefixCon [] (fmap (Hs.varPat srcSpan) vars)) + + wildPat = Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.WildPat Ghc.noExtField + + nothingExpr = Hs.var srcSpan (Hs.unqual srcSpan (Ghc.mkDataOcc "Nothing")) + + conMatch = + Hs.caseMatch srcSpan [conPat] (Common.makeGRHSs srcSpan successResult) + + wildMatch = + Hs.caseMatch srcSpan [wildPat] (Common.makeGRHSs srcSpan nothingExpr) + + caseExpr = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsCase Ghc.noAnn + (Hs.var srcSpan scrutVar) + (Hs.mg (Ghc.L srcSpan [conMatch, wildMatch])) + in Hs.lam srcSpan . Hs.mg $ + Ghc.L srcSpan + [ Hs.match srcSpan + [Hs.varPat srcSpan scrutVar] + (Common.makeGRHSs srcSpan caseExpr) + ] + +-- | Build @TypeName a b ...@ as an 'LHsType', parenthesised when there are +-- type variables (so it can be applied to the focus type without ambiguity). +mkOuterTy :: Ghc.SrcSpan -> Type.Type -> Ghc.LHsType Ghc.GhcPs +mkOuterTy srcSpan type_ = + let -- Each location wrapper is a fresh expression so its annotation type is + -- inferred per-position (a shared @loc@ monomorphises to the wrong one). + tv n = Hs.tyVar srcSpan (Ghc.L (Ghc.noAnnSrcSpan srcSpan) n) + initial = tv (Type.name type_) + applied = + List.foldl' + ( \acc v -> + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsAppTy Ghc.noExtField acc (tv v) + ) + initial + (Type.variables type_) + in case Type.variables type_ of + [] -> applied + _ -> Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ Ghc.HsParTy Ghc.noAnn applied + +-- | The unit type @()@ at the type level. +unitTy :: Ghc.SrcSpan -> Ghc.LHsType Ghc.GhcPs +unitTy srcSpan = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.HsTupleTy Ghc.noAnn Ghc.HsBoxedOrConstraintTuple [] + +-- | The unit expression @()@. +unitExpr :: Ghc.SrcSpan -> Ghc.LHsExpr Ghc.GhcPs +unitExpr srcSpan = + Ghc.L (Ghc.noAnnSrcSpan srcSpan) $ + Ghc.ExplicitTuple Ghc.noAnn [] Ghc.Boxed diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hs.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hs.hs new file mode 100644 index 00000000000..b57f6992138 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hs.hs @@ -0,0 +1,386 @@ +{-# LANGUAGE CPP #-} + +module PlutusTx.Plugin.Deriving.Hs + ( app, + bindStmt, + doExpr, + explicitList, + explicitTuple, + fieldOcc, + caseMatch, + funBind, + funMatch, + funTy, + grhs, + grhss, + importDecls, + intLit, + lam, + lastStmt, + letE, + lit, + match, + mg, + opApp, + par, + qual, + qualTyVar, + qualVar, + recField, + recFields, + recordCon, + string, + tupArg, + tyVar, + unqual, + unrestrictedArrow, + userTyVar, + valLocalBinds, + var, + varPat, + ) +where + +#if __GLASGOW_HASKELL__ < 912 +import qualified GHC.Data.Bag as Ghc +#endif +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc +import qualified GHC.Types.Fixity as Ghc +import qualified GHC.Types.SourceText as Ghc + +app :: + Ghc.SrcSpan -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +app s f x = Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsApp Ghc.noExtField f x +#else + Ghc.HsApp Ghc.noAnn f x +#endif + +bindStmt :: + Ghc.SrcSpan -> + Ghc.LPat Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LStmt Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +bindStmt s p e = + Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.BindStmt Ghc.noAnn p e + +doExpr :: Ghc.SrcSpan -> [Ghc.ExprLStmt Ghc.GhcPs] -> Ghc.LHsExpr Ghc.GhcPs +doExpr s stmts = + Ghc.L (Ghc.noAnnSrcSpan s) $ + Ghc.HsDo Ghc.noAnn (Ghc.DoExpr Nothing) (Ghc.L (Ghc.noAnnSrcSpan s) stmts) + +explicitList :: + Ghc.SrcSpan -> [Ghc.LHsExpr Ghc.GhcPs] -> Ghc.LHsExpr Ghc.GhcPs +explicitList s xs = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.ExplicitList Ghc.noAnn xs + +explicitTuple :: + Ghc.SrcSpan -> [Ghc.HsTupArg Ghc.GhcPs] -> Ghc.LHsExpr Ghc.GhcPs +explicitTuple s xs = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.ExplicitTuple Ghc.noAnn xs Ghc.Boxed + +fieldOcc :: Ghc.SrcSpan -> Ghc.RdrName -> Ghc.LFieldOcc Ghc.GhcPs +fieldOcc s r = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.FieldOcc + { Ghc.foExt = Ghc.noExtField + , Ghc.foLabel = Ghc.L (Ghc.noAnnSrcSpan s) r + } + +funBind :: + Ghc.SrcSpan -> + Ghc.OccName -> + Ghc.MatchGroup Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) -> + Ghc.LHsBind Ghc.GhcPs +funBind s f g = + Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.FunBind Ghc.noExtField (unqual s f) g + +grhs :: + Ghc.SrcSpan -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LGRHS Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +grhs s e = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.GRHS Ghc.noAnn [] e + +grhss :: + Ghc.SrcSpan -> + [Ghc.LGRHS Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs)] -> + Ghc.GRHSs Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +grhss _ xs = + Ghc.GRHSs Ghc.emptyComments xs $ Ghc.EmptyLocalBinds Ghc.noExtField + +importDecl :: + Ghc.SrcSpan -> + Ghc.ModuleName -> + Ghc.ModuleName -> + Ghc.LImportDecl Ghc.GhcPs +importDecl s m n = + Ghc.L (Ghc.noAnnSrcSpan s) $ + Ghc.ImportDecl + { Ghc.ideclExt = Ghc.XImportDeclPass + { Ghc.ideclAnn = Ghc.noAnn + , Ghc.ideclSourceText = Ghc.NoSourceText + , Ghc.ideclImplicit = False + } + , Ghc.ideclName = Ghc.L (Ghc.noAnnSrcSpan s) m + , Ghc.ideclPkgQual = Ghc.NoRawPkgQual + , Ghc.ideclSource = Ghc.NotBoot + , Ghc.ideclSafe = False + , Ghc.ideclQualified = Ghc.QualifiedPre + , Ghc.ideclAs = Just $ Ghc.L (Ghc.noAnnSrcSpan s) n + , Ghc.ideclImportList = Nothing + } + +importDecls :: + Ghc.SrcSpan -> + [(Ghc.ModuleName, Ghc.ModuleName)] -> + [Ghc.LImportDecl Ghc.GhcPs] +importDecls = fmap . uncurry . importDecl + +lam :: + Ghc.SrcSpan -> + Ghc.MatchGroup Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) -> + Ghc.LHsExpr Ghc.GhcPs +lam s mg_ = Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsLam Ghc.noAnn Ghc.LamSingle mg_ +#else + Ghc.HsLam Ghc.noExtField mg_ +#endif + +lastStmt :: + Ghc.SrcSpan -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LStmt Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +lastStmt s e = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.LastStmt Ghc.noExtField e Nothing noSyntaxExpr + +lit :: Ghc.SrcSpan -> Ghc.HsLit Ghc.GhcPs -> Ghc.LHsExpr Ghc.GhcPs +lit s l = Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsLit Ghc.noExtField l +#else + Ghc.HsLit Ghc.noAnn l +#endif + +noSyntaxExpr :: Ghc.SyntaxExpr Ghc.GhcPs +noSyntaxExpr = Ghc.noSyntaxExpr + +-- | Build a lambda match. The match context is always a (single) lambda, +-- so it is baked in rather than passed by the caller; this also avoids a +-- CPP-divergent context type in the signature. +match :: + Ghc.SrcSpan -> + [Ghc.LPat Ghc.GhcPs] -> + Ghc.GRHSs Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) -> + Ghc.LMatch Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +match s ps g = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.Match Ghc.noExtField (Ghc.LamAlt Ghc.LamSingle) (Ghc.L (Ghc.noAnnSrcSpan s) ps) g +#else + Ghc.Match Ghc.noAnn Ghc.LambdaExpr ps g +#endif + +-- | A prefix-function ('FunRhs') match with located patterns. @FunRhs@ gained +-- an annotation field and @m_pats@ became located in GHC 9.10. +funMatch :: + Ghc.SrcSpan -> + Ghc.LIdP Ghc.GhcPs -> + [Ghc.LPat Ghc.GhcPs] -> + Ghc.GRHSs Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) -> + Ghc.LMatch Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +funMatch s v ps g = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.Match Ghc.noExtField (Ghc.FunRhs v Ghc.Prefix Ghc.NoSrcStrict Ghc.noAnn) (Ghc.L (Ghc.noAnnSrcSpan s) ps) g +#else + Ghc.Match Ghc.noAnn (Ghc.FunRhs v Ghc.Prefix Ghc.NoSrcStrict) ps g +#endif + +-- | A @case@ alternative ('CaseAlt') match with located patterns. +caseMatch :: + Ghc.SrcSpan -> + [Ghc.LPat Ghc.GhcPs] -> + Ghc.GRHSs Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) -> + Ghc.LMatch Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +caseMatch s ps g = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.Match Ghc.noExtField Ghc.CaseAlt (Ghc.L (Ghc.noAnnSrcSpan s) ps) g +#else + Ghc.Match Ghc.noAnn Ghc.CaseAlt ps g +#endif + +-- | A @let ... in ...@ expression. The @let@/@in@ tokens moved into the +-- extension field in GHC 9.10. +letE :: + Ghc.SrcSpan -> + Ghc.HsLocalBinds Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +letE s binds body = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsLet Ghc.noAnn binds body +#else + Ghc.HsLet Ghc.noAnn Ghc.noHsTok binds Ghc.noHsTok body +#endif + +-- | A function type @a -> b@ (unrestricted arrow). The @HsFunTy@ extension +-- field became 'Ghc.NoExtField' in GHC 9.10. +funTy :: + Ghc.SrcSpan -> + Ghc.LHsType Ghc.GhcPs -> + Ghc.LHsType Ghc.GhcPs -> + Ghc.LHsType Ghc.GhcPs +funTy s a b = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsFunTy Ghc.noExtField unrestrictedArrow a b +#else + Ghc.HsFunTy Ghc.noAnn unrestrictedArrow a b +#endif + +-- | An integer overloaded literal. The @HsOverLit@ extension field became +-- 'Ghc.NoExtField' in GHC 9.10. +intLit :: Ghc.SrcSpan -> Integer -> Ghc.LHsExpr Ghc.GhcPs +intLit s n = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsOverLit Ghc.noExtField $ +#else + Ghc.HsOverLit Ghc.noAnn $ +#endif + Ghc.OverLit Ghc.noExtField (Ghc.HsIntegral (Ghc.IL Ghc.NoSourceText False n)) + +-- | The unrestricted function arrow @->@. Its token representation moved +-- into an annotation in GHC 9.10. +unrestrictedArrow :: Ghc.HsArrow Ghc.GhcPs +#if __GLASGOW_HASKELL__ >= 910 +unrestrictedArrow = Ghc.HsUnrestrictedArrow Ghc.noAnn +#else +unrestrictedArrow = Ghc.HsUnrestrictedArrow Ghc.noHsUniTok +#endif + +-- | Value local-bindings from a list of binds. @LHsBinds@ became a plain +-- list (was a @Bag@) in GHC 9.12. +valLocalBinds :: [Ghc.LHsBind Ghc.GhcPs] -> Ghc.HsLocalBinds Ghc.GhcPs +valLocalBinds binds = + Ghc.HsValBinds Ghc.noAnn $ + Ghc.ValBinds + Ghc.NoAnnSortKey +#if __GLASGOW_HASKELL__ >= 912 + binds +#else + (Ghc.listToBag binds) +#endif + [] + +-- | A user type-variable binder with no kind annotation. The binder layout +-- changed to @HsTvb@/@HsBndrVar@ in GHC 9.10, and the binder visibility flag +-- (@HsBndrVis@) replaced the unit flag for these binders. +userTyVar :: + Ghc.SrcSpan -> + Ghc.IdP Ghc.GhcPs -> +#if __GLASGOW_HASKELL__ >= 910 + Ghc.LHsTyVarBndr (Ghc.HsBndrVis Ghc.GhcPs) Ghc.GhcPs +#else + Ghc.LHsTyVarBndr () Ghc.GhcPs +#endif +userTyVar s v = + Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsTvb Ghc.noAnn (Ghc.HsBndrRequired Ghc.noExtField) (Ghc.HsBndrVar Ghc.noExtField (Ghc.L (Ghc.noAnnSrcSpan s) v)) (Ghc.HsBndrNoKind Ghc.noExtField) +#else + Ghc.UserTyVar Ghc.noAnn () (Ghc.L (Ghc.noAnnSrcSpan s) v) +#endif + +mg :: + Ghc.Located [Ghc.LMatch Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs)] -> + Ghc.MatchGroup Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +mg ms = + Ghc.MG +#if __GLASGOW_HASKELL__ >= 910 + (Ghc.Generated Ghc.OtherExpansion Ghc.SkipPmc) +#else + Ghc.Generated +#endif + (Ghc.noLocA (Ghc.unLoc ms)) + +opApp :: + Ghc.SrcSpan -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +opApp s l o r = Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.OpApp Ghc.noExtField l o r +#else + Ghc.OpApp Ghc.noAnn l o r +#endif + +par :: Ghc.SrcSpan -> Ghc.LHsExpr Ghc.GhcPs -> Ghc.LHsExpr Ghc.GhcPs +par s e = Ghc.L (Ghc.noAnnSrcSpan s) $ +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsPar Ghc.noAnn e +#else + Ghc.HsPar Ghc.noAnn Ghc.noHsTok e Ghc.noHsTok +#endif + +qual :: Ghc.SrcSpan -> Ghc.ModuleName -> Ghc.OccName -> Ghc.LIdP Ghc.GhcPs +qual s m n = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.mkRdrQual m n + +qualTyVar :: + Ghc.SrcSpan -> Ghc.ModuleName -> Ghc.OccName -> Ghc.LHsType Ghc.GhcPs +qualTyVar s m = tyVar s . qual s m + +qualVar :: + Ghc.SrcSpan -> Ghc.ModuleName -> Ghc.OccName -> Ghc.LHsExpr Ghc.GhcPs +qualVar s m = var s . qual s m + +recFields :: + [Ghc.LHsRecField Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs)] -> + Ghc.HsRecFields Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +recFields fs = +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsRecFields Ghc.noExtField fs Nothing +#else + Ghc.HsRecFields fs Nothing +#endif + +recField :: + Ghc.SrcSpan -> + Ghc.LFieldOcc Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs -> + Ghc.LHsRecField Ghc.GhcPs (Ghc.LHsExpr Ghc.GhcPs) +recField s f e = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.HsFieldBind Ghc.noAnn f e False + +recordCon :: + Ghc.SrcSpan -> + Ghc.LIdP Ghc.GhcPs -> + Ghc.HsRecordBinds Ghc.GhcPs -> + Ghc.LHsExpr Ghc.GhcPs +recordCon s c fs = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.RecordCon Ghc.noAnn c fs + +string :: String -> Ghc.HsLit Ghc.GhcPs +string = Ghc.HsString Ghc.NoSourceText . Ghc.mkFastString + +tupArg :: Ghc.LHsExpr Ghc.GhcPs -> Ghc.HsTupArg Ghc.GhcPs +#if __GLASGOW_HASKELL__ >= 910 +tupArg = Ghc.Present Ghc.noExtField +#else +tupArg = Ghc.Present Ghc.noAnn +#endif + +tyVar :: Ghc.SrcSpan -> Ghc.LIdP Ghc.GhcPs -> Ghc.LHsType Ghc.GhcPs +tyVar s x = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.HsTyVar Ghc.noAnn Ghc.NotPromoted x + +unqual :: Ghc.SrcSpan -> Ghc.OccName -> Ghc.LIdP Ghc.GhcPs +unqual s n = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.mkRdrUnqual n + +var :: Ghc.SrcSpan -> Ghc.LIdP Ghc.GhcPs -> Ghc.LHsExpr Ghc.GhcPs +var s x = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.HsVar Ghc.noExtField x + +varPat :: Ghc.SrcSpan -> Ghc.LIdP Ghc.GhcPs -> Ghc.LPat Ghc.GhcPs +varPat s x = Ghc.L (Ghc.noAnnSrcSpan s) $ Ghc.VarPat Ghc.noExtField x diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hsc.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hsc.hs new file mode 100644 index 00000000000..f2e5f8ae027 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Hsc.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE CPP #-} + +module PlutusTx.Plugin.Deriving.Hsc + ( addWarning, + throwError, + ) +where + +import qualified Control.Monad.IO.Class as IO +import qualified GHC as Ghc +import qualified GHC.Data.Bag as Ghc +import qualified GHC.Driver.Config.Diagnostic as Ghc +import qualified GHC.Driver.Errors.Types as Ghc +import qualified GHC.Plugins as Ghc +import qualified GHC.Types.Error as Ghc +import qualified GHC.Utils.Error as Ghc +import qualified GHC.Utils.Logger as Ghc + +-- | Adds a warning +addWarning :: Ghc.SrcSpan -> Ghc.SDoc -> Ghc.Hsc () +addWarning srcSpan msgDoc = do + logger <- Ghc.getLogger + IO.liftIO $ Ghc.logMsg + logger + Ghc.MCOutput + srcSpan + msgDoc + +-- | Throws an error +throwError :: Ghc.SrcSpan -> Ghc.SDoc -> Ghc.Hsc a +throwError srcSpan msgDoc = do + dynFlags <- Ghc.getDynFlags + let diagOpts = Ghc.initDiagOpts dynFlags + -- 1. Create the plain diagnostic + innerDiag = Ghc.mkPlainDiagnostic Ghc.WarningWithoutFlag [] msgDoc + + -- 2. Use the 'GhcUnknownMessage' wrapper with a 'Simple' constructor + -- This bypasses the need for phase-specific types like DsMessage. +#if __GLASGOW_HASKELL__ >= 910 + diagnostic = Ghc.GhcUnknownMessage (Ghc.UnknownDiagnostic (const Ghc.NoDiagnosticOpts) innerDiag) +#else + diagnostic = Ghc.GhcUnknownMessage (Ghc.UnknownDiagnostic innerDiag) +#endif + + -- 3. Create the envelope + msg = Ghc.mkPlainMsgEnvelope diagOpts srcSpan diagnostic + + Ghc.throwErrors $ Ghc.mkMessages (Ghc.unitBag msg) diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Constructor.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Constructor.hs new file mode 100644 index 00000000000..6866f36d938 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Constructor.hs @@ -0,0 +1,39 @@ +module PlutusTx.Plugin.Deriving.Type.Constructor + ( Constructor (..), + make, + ) +where + +import qualified Control.Monad as Monad +import qualified Data.List as List +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified PlutusTx.Plugin.Deriving.Type.Field as Field +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +data Constructor = Constructor + { name :: Ghc.IdP Ghc.GhcPs, + fields :: [Field.Field] + } + +make :: Ghc.SrcSpan -> Ghc.LConDecl Ghc.GhcPs -> Ghc.Hsc Constructor +make srcSpan lConDecl = do + (lIdP, hsConDeclDetails) <- case Ghc.unLoc lConDecl of + Ghc.ConDeclH98 _ x _ _ _ y _ -> pure (x, y) + _ -> Hsc.throwError srcSpan $ Ghc.text "unsupported LConDecl" + theFields <- case hsConDeclDetails of + Ghc.RecCon lConDeclFields -> + fmap concat . Monad.forM (Ghc.unLoc lConDeclFields) $ \lConDeclField -> do + (lFieldOccs, lHsType) <- case Ghc.unLoc lConDeclField of + Ghc.ConDeclField _ x y _ -> pure (x, y) + mapM (Field.make srcSpan lHsType) lFieldOccs + Ghc.PrefixCon _ scaledTypes -> + pure $ List.zipWith + (\i (Ghc.HsScaled _ lHsType) -> Field.Field + { Field.name = Ghc.mkVarOcc ("_field" <> show (i :: Int)), + Field.type_ = Ghc.unLoc lHsType + }) + [0 ..] + scaledTypes + _ -> Hsc.throwError srcSpan $ Ghc.text "unsupported HsConDeclDetails" + pure Constructor {name = Ghc.unLoc lIdP, fields = theFields} diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Field.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Field.hs new file mode 100644 index 00000000000..5828a7501f7 --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Field.hs @@ -0,0 +1,37 @@ +module PlutusTx.Plugin.Deriving.Type.Field + ( Field (..), + make, + isOptional, + ) +where + +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +data Field = Field + { name :: Ghc.OccName, + type_ :: Ghc.HsType Ghc.GhcPs + } + +make :: + Ghc.SrcSpan -> + Ghc.LHsType Ghc.GhcPs -> + Ghc.LFieldOcc Ghc.GhcPs -> + Ghc.Hsc Field +make srcSpan lHsType lFieldOcc = do + lRdrName <- case Ghc.unLoc lFieldOcc of + Ghc.FieldOcc _ x -> pure x + occName <- case Ghc.unLoc lRdrName of + Ghc.Unqual x -> pure x + _ -> Hsc.throwError srcSpan $ Ghc.text "unsupported RdrName" + pure Field {name = occName, type_ = Ghc.unLoc lHsType} + +isOptional :: Field -> Bool +isOptional field = case type_ field of + Ghc.HsAppTy _ lHsType _ -> case Ghc.unLoc lHsType of + Ghc.HsTyVar _ _ lIdP -> case Ghc.unLoc lIdP of + Ghc.Unqual occName -> Ghc.occNameString occName == "Maybe" + _ -> False + _ -> False + _ -> False diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Type.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Type.hs new file mode 100644 index 00000000000..966a8e8b71a --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Type/Type.hs @@ -0,0 +1,57 @@ +{-# LANGUAGE CPP #-} + +module PlutusTx.Plugin.Deriving.Type.Type + ( Type (..), + make, + qualifiedName, + ) +where + +import qualified Control.Monad as Monad +#if __GLASGOW_HASKELL__ >= 910 +import qualified PlutusTx.Plugin.Deriving.Hsc as Hsc +#endif +import qualified PlutusTx.Plugin.Deriving.Type.Constructor as Constructor +import qualified GHC.Hs as Ghc +import qualified GHC.Plugins as Ghc + +data Type = Type + { name :: Ghc.IdP Ghc.GhcPs, + variables :: [Ghc.IdP Ghc.GhcPs], + constructors :: [Constructor.Constructor] + } + +make :: + Ghc.LIdP Ghc.GhcPs -> + Ghc.LHsQTyVars Ghc.GhcPs -> + [Ghc.LConDecl Ghc.GhcPs] -> + Ghc.SrcSpan -> + Ghc.Hsc Type +make lIdP lHsQTyVars lConDecls srcSpan = do + lHsTyVarBndrs <- case lHsQTyVars of + Ghc.HsQTvs _ hsq_explicit -> pure hsq_explicit + theVariables <- Monad.forM lHsTyVarBndrs $ \lHsTyVarBndr -> + case Ghc.unLoc lHsTyVarBndr of +#if __GLASGOW_HASKELL__ >= 910 + Ghc.HsTvb _ _ (Ghc.HsBndrVar _ (Ghc.L _ var)) _ -> pure var + -- HsBndrWildCard and the XTyVarBndr extension constructor + _ -> Hsc.throwError srcSpan $ Ghc.text "unknown type variable binder" +#else + Ghc.UserTyVar _ _ (Ghc.L _ var) -> pure var + Ghc.KindedTyVar _ _ (Ghc.L _ var) _ -> pure var +#endif + theConstructors <- mapM (Constructor.make srcSpan) lConDecls + pure + Type + { name = Ghc.unLoc lIdP, + variables = theVariables, + constructors = theConstructors + } + +qualifiedName :: Ghc.ModuleName -> Type -> String +qualifiedName moduleName type_ = + mconcat + [ Ghc.moduleNameString moduleName, + ".", + Ghc.occNameString . Ghc.rdrNameOcc $ name type_ + ] diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Via.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Via.hs new file mode 100644 index 00000000000..31776a55e6e --- /dev/null +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Deriving/Via.hs @@ -0,0 +1,15 @@ +{-# LANGUAGE EmptyDataDecls #-} + +-- | The @DerivingVia@ sentinel type recognised by the Plinth deriving plugin. +module PlutusTx.Plugin.Deriving.Via (Plinth) where + +-- | Used as a @DerivingVia@ target to activate the deriving plugin, e.g. +-- +-- > data Foo = Foo Integer Integer +-- > deriving AsData via Plinth +-- +-- When the plugin is active the deriving clause is rewritten away at parse +-- time, so @Plinth@ never actually has to be in scope. Defining it as a real +-- type means that when the plugin is /not/ loaded GHC reports a clean error +-- instead of a confusing one. +data Plinth diff --git a/plutus-tx-plugin/test-frontend-plugin/AsData/Spec.hs b/plutus-tx-plugin/test-frontend-plugin/AsData/Spec.hs new file mode 100644 index 00000000000..c797ddeb581 --- /dev/null +++ b/plutus-tx-plugin/test-frontend-plugin/AsData/Spec.hs @@ -0,0 +1,50 @@ +{-# OPTIONS_GHC -fplugin Plinth.Plugin #-} +{-# OPTIONS_GHC -fplugin-opt Plinth.Plugin:context-level=0 #-} +{-# OPTIONS_GHC -fplugin-opt Plinth.Plugin:datatypes=BuiltinCasing #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} + +{-| Activation test for the Plinth @deriving via@ pass. + +This module is compiled with the Plinth plugin (@-fplugin Plinth.Plugin@), into +which the deriving pass is wired. The plugin must rewrite the +@deriving AsData via Plinth@ clause below into a @BuiltinData@-backed newtype +plus bidirectional pattern synonyms @Circle@/@Rectangle@ (and inject the +@PlutusTx.*@ imports the generated code uses). + +If the deriving pass does /not/ fire, the clause survives to the renamer — +where @AsData@ is not a real class and the @Circle@ synonym does not exist — so +this module fails to compile. Thus successful compilation /is/ the test. +-} +module AsData.Spec (tests) where + +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (testCase, (@?=)) + +data Shape + = Circle Integer + | Rectangle Integer Integer + deriving AsData via Plinth + +-- | These reference both plugin-generated pattern synonyms (so neither is +-- flagged unused under @-Werror@). The trailing wildcard makes each match +-- total and non-overlapping regardless of the generated @COMPLETE@ pragma, +-- so they do not trip @-Werror@ on incomplete/overlapping patterns either. +isCircle :: Shape -> Bool +isCircle (Circle _) = True +isCircle _ = False + +isRectangle :: Shape -> Bool +isRectangle (Rectangle _ _) = True +isRectangle _ = False + +tests :: TestTree +tests = + testGroup + "AsData via Plinth" + [ testCase "plugin fires; generated synonyms construct and match" $ do + isCircle (Circle 1) @?= True + isRectangle (Circle 1) @?= False + isRectangle (Rectangle 2 3) @?= True + ] diff --git a/plutus-tx-plugin/test-frontend-plugin/Match/Spec.hs b/plutus-tx-plugin/test-frontend-plugin/Match/Spec.hs new file mode 100644 index 00000000000..86ac2d04ca8 --- /dev/null +++ b/plutus-tx-plugin/test-frontend-plugin/Match/Spec.hs @@ -0,0 +1,40 @@ +{-# OPTIONS_GHC -fplugin Plinth.Plugin #-} +{-# OPTIONS_GHC -fplugin-opt Plinth.Plugin:context-level=0 #-} +{-# OPTIONS_GHC -fplugin-opt Plinth.Plugin:datatypes=BuiltinCasing #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} + +{-| Activation test for the @Match@ generator (the CPS destructor). + +Compiled with the Plinth plugin, @deriving (AsData, Match) via Plinth@ must +produce the @Circle@/@Rectangle@ pattern synonyms /and/ a destructor + +> matchShape :: Shape -> (Integer -> r) -> (Integer -> Integer -> r) -> r + +'firstField' constructs values via the AsData synonyms and destructures them +via @matchShape@, exercising both continuations and the (now head/tail-free) +field decoding. If the plugin does not fire, @matchShape@ is undefined and this +module fails to compile. +-} +module Match.Spec (tests) where + +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (testCase, (@?=)) + +data Shape + = Circle Integer + | Rectangle Integer Integer + deriving (AsData, Match) via Plinth + +firstField :: Shape -> Integer +firstField s = matchShape s (\r -> r) (\w _ -> w) + +tests :: TestTree +tests = + testGroup + "Match via Plinth" + [ testCase "matchShape dispatches on the tag and decodes fields" $ do + firstField (Circle 7) @?= 7 + firstField (Rectangle 3 5) @?= 3 + ] diff --git a/plutus-tx-plugin/test-frontend-plugin/Spec.hs b/plutus-tx-plugin/test-frontend-plugin/Spec.hs index fe60ec3ad25..fdece905859 100644 --- a/plutus-tx-plugin/test-frontend-plugin/Spec.hs +++ b/plutus-tx-plugin/test-frontend-plugin/Spec.hs @@ -1,10 +1,12 @@ module Main (main) where +import AsData.Spec qualified as AsData import Inlineable.Spec qualified as Inlineable +import Match.Spec qualified as Match import NoStrict.Spec qualified as NoStrict import Strict.Spec qualified as Strict -import Test.Tasty (TestTree, defaultMain) +import Test.Tasty (TestTree, defaultMain, testGroup) import Test.Tasty.Extras (runTestNested) main :: IO () @@ -12,9 +14,14 @@ main = defaultMain tests tests :: TestTree tests = - runTestNested - ["test-frontend-plugin"] - [ Strict.tests - , NoStrict.tests - , Inlineable.tests + testGroup + "frontend-plugin-tests" + [ runTestNested + ["test-frontend-plugin"] + [ Strict.tests + , NoStrict.tests + , Inlineable.tests + ] + , AsData.tests + , Match.tests ]