Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new step Signature that will format function signatures #350

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions lib/Language/Haskell/Stylish/Editor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module Language.Haskell.Stylish.Editor
, delete
, deleteLine
, insert
, noop
) where


Expand Down Expand Up @@ -84,6 +85,9 @@ applyChanges changes0
change :: Block a -> ([a] -> [a]) -> Change a
change = Change

--------------------------------------------------------------------------------
noop :: Block a -> Change a
noop = flip change $ id

--------------------------------------------------------------------------------
-- | Change a single line for some other lines
Expand Down
6 changes: 5 additions & 1 deletion lib/Language/Haskell/Stylish/GHC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Language.Haskell.Stylish.GHC
-- * Unsafe getters
, unsafeGetRealSrcSpan
, getEndLineUnsafe
, getEndColumnUnsafe
, getStartLineUnsafe
-- * Standard settings
, baseDynFlags
Expand All @@ -33,7 +34,7 @@ import qualified Outputable as GHC
import PlatformConstants (PlatformConstants (..))
import SrcLoc (GenLocated (..), Located, RealLocated,
RealSrcSpan, SrcSpan (..), srcSpanEndLine,
srcSpanStartLine)
srcSpanStartLine, srcSpanEndCol)
import ToolSettings (ToolSettings (..))

unsafeGetRealSrcSpan :: Located a -> RealSrcSpan
Expand All @@ -47,6 +48,9 @@ getStartLineUnsafe = srcSpanStartLine . unsafeGetRealSrcSpan
getEndLineUnsafe :: Located a -> Int
getEndLineUnsafe = srcSpanEndLine . unsafeGetRealSrcSpan

getEndColumnUnsafe :: Located a -> Int
getEndColumnUnsafe = srcSpanEndCol . unsafeGetRealSrcSpan

dropAfterLocated :: Maybe (Located a) -> [RealLocated b] -> [RealLocated b]
dropAfterLocated loc xs = case loc of
Just (L (RealSrcSpan rloc) _) ->
Expand Down
135 changes: 135 additions & 0 deletions lib/Language/Haskell/Stylish/Step/Signature.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE RecordWildCards #-}
module Language.Haskell.Stylish.Step.Signature where

import RdrName (RdrName)
import SrcLoc (GenLocated (..), Located)
import GHC.Hs.Decls
import GHC.Hs.Binds
import GHC.Hs.Types
import GHC.Hs.Extension (GhcPs)

--------------------------------------------------------------------------------
import Language.Haskell.Stylish.Block
import Language.Haskell.Stylish.Step
import Language.Haskell.Stylish.Module
import Language.Haskell.Stylish.Editor (change, noop)
import Language.Haskell.Stylish.GHC (getStartLineUnsafe, getEndLineUnsafe, getEndColumnUnsafe)
import Language.Haskell.Stylish.Editor (Change, applyChanges)
import Language.Haskell.Stylish.Printer

-- TODO unify with type alias from Data.hs
type ChangeLine = Change String

data MaxColumns
= MaxColumns !Int
| NoMaxColumns
deriving (Show, Eq)

fits :: Int -> MaxColumns -> Bool
fits _ NoMaxColumns = True
fits v (MaxColumns limit) = v <= limit

data Config = Config
{ cMaxColumns :: MaxColumns
}

step :: Config -> Step
step cfg = makeStep "Signature" (\ls m -> applyChanges (changes cfg m) ls)

changes :: Config -> Module -> [ChangeLine]
changes cfg m = fmap (formatSignatureDecl cfg m) (topLevelFunctionSignatures m)

topLevelFunctionSignatures :: Module -> [Located SignatureDecl]
topLevelFunctionSignatures = queryModule @(Located (HsDecl GhcPs)) \case
L pos (SigD _ (TypeSig _ [name] (HsWC _ (HsIB _ (L _ funTy@(HsFunTy _ _ _ )))))) ->
[L pos $ MkSignatureDecl name (listParameters funTy) []]
L pos (SigD _ (TypeSig _ [name] (HsWC _ (HsIB _ (L _ (HsQualTy _ (L _ contexts) (L _ funTy))))))) ->
[L pos $ MkSignatureDecl name (listParameters funTy) (contexts >>= listContexts)]
_ -> []

listParameters :: HsType GhcPs -> [Located RdrName]
listParameters (HsFunTy _ (L _ arg2) (L _ arg3)) = listParameters arg2 <> listParameters arg3
listParameters (HsTyVar _ _promotionFlag name) = [name]
listParameters _ = []

listContexts :: Located (HsType GhcPs) -> [Located RdrName]
listContexts (L _ (HsTyVar _ _ name)) = [name]
listContexts (L _ (HsAppTy _ arg1 arg2)) = listContexts arg1 <> listContexts arg2
listContexts _ = []

data SignatureDecl = MkSignatureDecl
{ sigName :: Located RdrName
, sigParameters :: [Located RdrName]
, sigConstraints :: [Located RdrName]
}

formatSignatureDecl :: Config -> Module -> Located SignatureDecl -> ChangeLine
formatSignatureDecl cfg@Config{..} m ldecl@(L _ decl)
| fits declLength cMaxColumns = noop block
| otherwise = change block (const (printDecl cfg m decl))

where
block = Block (getStartLineUnsafe ldecl) (getEndLineUnsafe ldecl)
declLength = getEndColumnUnsafe ldecl

printDecl :: Config -> Module -> SignatureDecl -> Lines
printDecl Config{..} m MkSignatureDecl{..} = runPrinter_ printerConfig [] m do
printFirstLine
printSecondLine
printRemainingLines
where

----------------------------------------------------------------------------------------

printFirstLine =
putRdrName sigName >> space >> putText "::" >> newline

----------------------------------------------------------------------------------------

printSecondLine =
if hasConstraints then printConstraints
else printFirstParameter

printConstraints =
spaces 5 >> putText "("
>> (traverse (\ctr -> printConstraint ctr >> putText ", ") (init groupConstraints))
>> (printConstraint $ last groupConstraints)
>> putText ")" >> newline

groupConstraints = zip (dropEvery sigConstraints 2) (dropEvery (tail sigConstraints) 2)

printConstraint (tc, tp) = putRdrName tc >> space >> putRdrName tp

printFirstParameter =
spaces 5 >> (putRdrName $ head sigParameters) >> newline

----------------------------------------------------------------------------------------

printRemainingLines =
if hasConstraints then
printRemainingLine "=>" (head sigParameters)
>> traverse (printRemainingLine "->") (tail sigParameters)
else
traverse (printRemainingLine "->") (tail sigParameters)

printRemainingLine prefix parameter =
spaces 2 >> putText prefix >> space >> (putRdrName parameter) >> newline

----------------------------------------------------------------------------------------

printerConfig = PrinterConfig
{ columns = case cMaxColumns of
NoMaxColumns -> Nothing
MaxColumns n -> Just n
}

hasConstraints = not $ null sigConstraints

-- 99 problems :)
dropEvery :: [a] -> Int -> [a]
dropEvery xs n
| length xs < n = xs
| otherwise = take (n-1) xs ++ dropEvery (drop n xs) n
3 changes: 3 additions & 0 deletions stylish-haskell.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Library
Language.Haskell.Stylish.Step.Imports
Language.Haskell.Stylish.Step.ModuleHeader
Language.Haskell.Stylish.Step.LanguagePragmas
Language.Haskell.Stylish.Step.Signature
Language.Haskell.Stylish.Step.SimpleAlign
Language.Haskell.Stylish.Step.Squash
Language.Haskell.Stylish.Step.Tabs
Expand Down Expand Up @@ -137,6 +138,8 @@ Test-suite stylish-haskell-tests
Language.Haskell.Stylish.Step.ModuleHeader.Tests
Language.Haskell.Stylish.Step.LanguagePragmas
Language.Haskell.Stylish.Step.LanguagePragmas.Tests
Language.Haskell.Stylish.Step.Signature
Language.Haskell.Stylish.Step.Signature.Tests
Language.Haskell.Stylish.Step.SimpleAlign
Language.Haskell.Stylish.Step.SimpleAlign.Tests
Language.Haskell.Stylish.Step.Squash
Expand Down
144 changes: 144 additions & 0 deletions tests/Language/Haskell/Stylish/Step/Signature/Tests.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
module Language.Haskell.Stylish.Step.Signature.Tests
( tests
) where

import Language.Haskell.Stylish.Step.Signature
import Language.Haskell.Stylish.Tests.Util (assertSnippet, testStep)
import Test.Framework (Test, testGroup)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (Assertion, (@=?))

tests :: Test
tests = testGroup "Language.Haskell.Stylish.Step.Signature.Tests"
[ testCase "do not wrap signature if it fits max column length" case00
, testCase "wrap signature if it does not fit max column length" case01
, testCase "how it behaves when there is a list of constraints" case02
-- , testCase "how it behaves when there is a explicit forall" case03
-- , testCase "how it behaves when there is a explicit forall" case04
-- , testCase "how it behaves when there is a large function in the argument" case05
]

config :: Int -> Config
config cMaxColumns = Config
{ cMaxColumns = MaxColumns cMaxColumns
}

case00 :: Assertion
case00 = expected @=? testStep (step $ config 80) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: a -> b -> a"
, "fooBar v _ = v"
]
expected = input

case01 :: Assertion
case01 = expected @=? testStep (step $ config 20) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: a -> b -> a"
, "fooBar v _ = v"
]
expected = unlines
[ "module Herp where"
, ""
, "fooBar ::"
, " a"
, " -> b"
, " -> a"
, "fooBar v _ = v"
]

case02 :: Assertion
case02 = expected @=? testStep (step $ config 20) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: (Eq a, Show b) => a -> b -> a"
, "fooBar v _ = v"
]
expected = unlines
[ "module Herp where"
, ""
, "fooBar ::"
, " (Eq a, Show b)"
, " => a"
, " -> b"
, " -> a"
, "fooBar v _ = v"
]

case03 :: Assertion
case03 = expected @=? testStep (step $ config 20) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: forall a . b. (Eq a, Show b) => a -> b -> a"
, "fooBar v _ = v"
]
expected = unlines
[ "module Herp where"
, ""
, "fooBar ::"
, " forall a . b."
, " (Eq a, Show b)"
, " => a"
, " -> b"
, " -> a"
, "fooBar v _ = v"
]

case04 :: Assertion
case04 = expected @=? testStep (step $ config 20) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: forall a . b. c. (Eq a, Show b, Ord c) => a -> b -> c -> a"
, "fooBar v _ _ = v"
]
expected = unlines
[ "module Herp where"
, ""
, "fooBar ::"
, " forall a . b. ("
, " Eq a"
, " , Show b"
, " , Ord c)"
, " )"
, " => a"
, " -> b"
, " -> a"
, "fooBar v _ = v"
]

case05 :: Assertion
case05 = expected @=? testStep (step $ config 20) input
where
input = unlines
[ "module Herp where"
, ""
, "fooBar :: => a -> (forall c. Eq c => c -> a -> a) -> a"
, "fooBar v _ = v"
]
expected = unlines
[ "module Herp where"
, ""
, "fooBar ::"
, " => a"
, " -> ( forall c. Eq c"
, " => c"
, " -> a"
, " -> a"
, " )"
, " -> a"
, "fooBar v _ = v"
]
2 changes: 2 additions & 0 deletions tests/TestSuite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import qualified Language.Haskell.Stylish.Step.Imports.Tests
import qualified Language.Haskell.Stylish.Step.Imports.FelixTests
import qualified Language.Haskell.Stylish.Step.ModuleHeader.Tests
import qualified Language.Haskell.Stylish.Step.LanguagePragmas.Tests
import qualified Language.Haskell.Stylish.Step.Signature.Tests
import qualified Language.Haskell.Stylish.Step.SimpleAlign.Tests
import qualified Language.Haskell.Stylish.Step.Squash.Tests
import qualified Language.Haskell.Stylish.Step.Tabs.Tests
Expand All @@ -34,6 +35,7 @@ main = defaultMain
, Language.Haskell.Stylish.Step.Imports.FelixTests.tests
, Language.Haskell.Stylish.Step.LanguagePragmas.Tests.tests
, Language.Haskell.Stylish.Step.ModuleHeader.Tests.tests
, Language.Haskell.Stylish.Step.Signature.Tests.tests
, Language.Haskell.Stylish.Step.SimpleAlign.Tests.tests
, Language.Haskell.Stylish.Step.Squash.Tests.tests
, Language.Haskell.Stylish.Step.Tabs.Tests.tests
Expand Down