{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Test.Cardano.Ledger.Shelley.Serialisation.GoldenUtils (
  checkEncoding,
  checkEncodingCBOR,
  checkEncodingCBORAnnotated,
  ToTokens (..),
  roundTripFailure,
  checkEncodingCBORDecodeFailure,
) where

import Cardano.Ledger.Binary (
  Annotator,
  DecCBOR (..),
  DecoderError,
  EncCBOR (..),
  EncCBORGroup (..),
  Encoding,
  ToCBOR (..),
  Tokens (..),
  Version,
  decodeFullAnnotator,
  decodeFullDecoder,
  decodeTerm,
  encCBOR,
  fromPlainEncoding,
  serialize,
  serialize',
 )
import qualified Codec.CBOR.Encoding as CBOR (Encoding (..))
import Control.Exception (throwIO)
import Control.Monad (unless)
import qualified Data.ByteString.Lazy as BSL (ByteString)
import Data.String (fromString)
import GHC.Stack
import qualified Prettyprinter as Pretty
import Test.Cardano.Ledger.Binary.TreeDiff (ansiDocToString, diffExpr)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (Assertion, assertFailure, testCase, (@?=))

-- | Assert that the decoding action succeeds on @x@ and yields a value equal
-- to @x@.
--
-- On a 'DecoderError' the test fails with a message showing both the original
-- value and the error. On success the decoded value is compared against @x@
-- (decoded value as "actual", @x@ as "expected").
expectDecodingSuccess :: (HasCallStack, Show a, Eq a) => (a -> Either DecoderError a) -> a -> IO ()
expectDecodingSuccess action x =
  case action x of
    Left e ->
      assertFailure $
        "could not decode serialization of " ++ show x ++ ", " ++ show e
    -- HUnit's (@?=) takes the actual value first, the expected value second.
    Right y -> y @?= x

-- | Assert that the decoding action fails on @x@.
--
-- Any 'DecoderError' is the expected outcome. A successful decode fails the
-- test, reporting the value that unexpectedly decoded.
expectDecodingFailure :: (HasCallStack, Show a) => (a -> Either DecoderError a) -> a -> IO ()
expectDecodingFailure action x =
  case action x of
    Left _ -> pure ()
    Right _ ->
      assertFailure $ "Did not expect successful decoding of " ++ show x

-- | Encode a value with @encode@, serialize that encoding at version @v@, and
-- feed the resulting lazy bytes straight back into @decode@.
roundtrip ::
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Either DecoderError a
roundtrip v encode decode = decode . serialize v . encode

-- | Round-trip assertion expecting the decode to succeed and to reproduce the
-- original value (see 'roundtrip' and 'expectDecodingSuccess').
--
-- The 'HasCallStack' constraint lets the call stack handed to the HUnit
-- assertion extend to this function's caller, instead of starting here.
roundTripSuccess ::
  (HasCallStack, Show a, Eq a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Assertion
roundTripSuccess v encode decode x =
  expectDecodingSuccess (roundtrip v encode decode) x

-- | Round-trip assertion expecting the decode to /fail/ (see 'roundtrip' and
-- 'expectDecodingFailure').
--
-- The 'HasCallStack' constraint lets the call stack handed to the HUnit
-- assertion extend to this function's caller, instead of starting here.
roundTripFailure ::
  (HasCallStack, Show a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Assertion
roundTripFailure v encode decode x =
  expectDecodingFailure (roundtrip v encode decode) x

-- | Golden test for an explicit encoder/decoder pair.
--
-- The serialized encoding of @x@ must match the serialized expected tokens
-- @t@ byte-for-byte, and decoding it must give back @x@ (via
-- 'roundTripSuccess'). The test case is named @golden_serialize_@ followed
-- by @name@.
checkEncoding ::
  (HasCallStack, Show a, Eq a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncoding v encode decode name x t =
  checkEncodingWithRoundtrip v encode decode roundTripSuccess name x t

-- | Core golden-test builder shared by the @checkEncoding*@ functions.
--
-- The resulting test case is named @golden_serialize_@ followed by @name@.
-- It does two things, in order:
--
-- 1. Serializes the expected tokens @t@ and the encoding of @x@ at version
--    @v@. If the two byte strings differ, both are decoded as generic CBOR
--    terms and the test fails with a tree diff of expected vs. actual.
--
-- 2. Otherwise, runs the supplied round-trip assertion on @x@.
--
-- If either byte string cannot be decoded as a CBOR term while the diff is
-- being built, the resulting 'DecoderError' is thrown as an exception.
checkEncodingWithRoundtrip ::
  HasCallStack =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  (Version -> (a -> Encoding) -> (BSL.ByteString -> Either DecoderError a) -> a -> Assertion) ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingWithRoundtrip v encode decode roundTrip name x t =
  testCase testName $ do
    unless (expectedBinary == actualBinary) $ do
      expectedTerms <- getTerms "expected" expectedBinary
      actualTerms <- getTerms "actual" actualBinary
      assertFailure . ansiDocToString $
        Pretty.vsep
          [ "Serialization did not match:"
          , Pretty.indent 2 $ diffExpr expectedTerms actualTerms
          ]
    -- Only reached when the bytes matched: 'assertFailure' above throws.
    roundTrip v encode decode x
  where
    -- Decode raw bytes as an untyped CBOR term, using @lbl@ as the decoder
    -- label; a decoding error is rethrown in IO.
    getTerms lbl = either throwIO pure . decodeFullDecoder v lbl decodeTerm
    expectedBinary = serialize v t
    actualBinary = serialize v $ encode x
    testName = "golden_serialize_" <> name

-- | Golden test using a type's 'EncCBOR' / 'DecCBOR' instances.
--
-- The encoding of @x@ must match the expected tokens @t@, but decoding it back
-- is expected to /fail/ (see 'roundTripFailure').
checkEncodingCBORDecodeFailure ::
  (HasCallStack, DecCBOR a, EncCBOR a, Show a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBORDecodeFailure v name x t =
  checkEncodingWithRoundtrip v encCBOR decodeFull roundTripFailure name x t
  where
    -- The test name doubles as the decoder's label.
    decodeFull = decodeFullDecoder v (fromString name) decCBOR

-- | Golden test using a type's 'EncCBOR' / 'DecCBOR' instances.
--
-- The encoding of @x@ must match the expected tokens @t@, and decoding it back
-- must reproduce @x@ (see 'roundTripSuccess').
checkEncodingCBOR ::
  (HasCallStack, DecCBOR a, EncCBOR a, Show a, Eq a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBOR v name x t =
  checkEncodingWithRoundtrip v encCBOR decodeFull roundTripSuccess name x t
  where
    -- The test name doubles as the decoder's label.
    decodeFull = decodeFullDecoder v (fromString name) decCBOR

-- | Golden test for a type that can be decoded both through 'Annotator' and
-- directly through 'DecCBOR'.
--
-- The value is encoded with 'toCBOR', lifted by 'fromPlainEncoding'. The
-- expected tokens @t@ are serialized first and wrapped as a single
-- pre-encoded chunk ('TkEncoded').
--
-- Two test cases are grouped under \"with and without Annotator\". The first
-- decodes with 'decodeFullAnnotator', the second with 'decodeFullDecoder';
-- both expect a successful round trip.
--
-- NOTE(review): both cases get the same test name (@golden_serialize_@
-- followed by @name@), so tasty patterns cannot select just one of them.
checkEncodingCBORAnnotated ::
  (HasCallStack, DecCBOR (Annotator a), DecCBOR a, ToCBOR a, Show a, Eq a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBORAnnotated v name x t =
  testGroup
    "with and without Annotator"
    [ mkCase decodeAnnotated
    , mkCase decodePlain
    ]
  where
    -- The test name doubles as the decoder's label.
    lbl = fromString name
    decodeAnnotated = decodeFullAnnotator v lbl decCBOR
    decodePlain = decodeFullDecoder v lbl decCBOR
    encodePlain = fromPlainEncoding . toCBOR
    annTokens = T $ TkEncoded $ serialize' v t
    -- Both cases share everything except the decoder.
    mkCase dec =
      checkEncodingWithRoundtrip v encodePlain dec roundTripSuccess name x annTokens

-- | Description of the expected CBOR output of a golden test. It is assembled
-- from raw token transformers, individually encoded values, encoded groups,
-- and concatenations of these.
data ToTokens where
  -- | A raw token-stream transformer.
  T :: (Tokens -> Tokens) -> ToTokens
  -- | Any value with an 'EncCBOR' instance.
  S :: forall a. EncCBOR a => a -> ToTokens
  -- | Any value with an 'EncCBORGroup' instance.
  G :: forall a. EncCBORGroup a => a -> ToTokens
  -- | Two descriptions, left one first.
  Plus :: ToTokens -> ToTokens -> ToTokens

-- | Turn a 'ToTokens' description into an 'Encoding'.
instance EncCBOR ToTokens where
  -- Raw tokens are wrapped in a plain CBOR 'CBOR.Encoding' and then lifted.
  encCBOR (T xs) = fromPlainEncoding (CBOR.Encoding xs)
  encCBOR (S s) = encCBOR s
  encCBOR (G g) = encCBORGroup g
  -- Left part first, then right part.
  encCBOR (Plus a b) = encCBOR a <> encCBOR b

-- | Combining two descriptions builds a 'Plus' node; the left description is
-- encoded before the right one.
instance Semigroup ToTokens where
  (<>) = Plus

-- | The empty description: the identity token transformer, which leaves the
-- token stream unchanged.
instance Monoid ToTokens where
  mempty = T id