Skip to content

Commit

Permalink
Get rid of hashing causing non-determinismistic compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewDaggitt committed Nov 7, 2024
1 parent c34d245 commit 75549fc
Show file tree
Hide file tree
Showing 16 changed files with 120 additions and 120 deletions.
5 changes: 2 additions & 3 deletions vehicle/src/Vehicle/Backend/Queries/PostProcessing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import Control.Monad.Reader (MonadReader (..))
import Control.Monad.State (get)
import Data.Bifunctor (Bifunctor (..))
import Data.Foldable (foldlM)
import Data.HashMap.Strict qualified as HashMap
import Data.LinkedHashMap qualified as LinkedHashMap
import Data.List (sort, sortOn)
import Data.List.NonEmpty (NonEmpty (..))
Expand Down Expand Up @@ -96,7 +95,7 @@ reconstructNetworkTensorVars ::
reconstructNetworkTensorVars GlobalCtx {..} solutions = do
let networkApplicationInfos = snd <$> LinkedHashMap.toList networkApplications
let networkVariables = Set.fromList $ concatMap (\r -> [inputVariable r, outputVariable r]) networkApplicationInfos
let allTensorVars = filter (\(var, _) -> var `Set.member` networkVariables) $ HashMap.toList tensorVariableInfo
let allTensorVars = filter (\(var, _) -> var `Set.member` networkVariables) $ Map.toList tensorVariableInfo
let networkTensorVars = sortOn fst allTensorVars
let mkStep (var, TensorVariableInfo {..}) = ReconstructTensor OtherVariable tensorVariableShape var elementVariables
return $ foldr (\v -> (mkStep v :)) solutions networkTensorVars
Expand Down Expand Up @@ -282,7 +281,7 @@ compileQueryVariables ::
compileQueryVariables globalCtx compileVariable metaNetworkApps assertions = do
-- Compute the set of new input and output variables
let initialState = IndexingState mempty mempty mempty
let tensorVars = HashMap.toList (tensorVariableInfo globalCtx)
let tensorVars = sortOn fst $ Map.toList (tensorVariableInfo globalCtx)
let usedNetworkTensorVars = Set.fromList $ concatMap (\x -> [inputVariable x, outputVariable x]) metaNetworkApps
let compileVars = compileTensorVariable compileVariable (globalBoundVarCtx globalCtx) usedNetworkTensorVars
indexingState@IndexingState {..} <- foldlM compileVars initialState tensorVars
Expand Down
23 changes: 12 additions & 11 deletions vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Use fewer imports" #-}
module Vehicle.Backend.Queries.UserVariableElimination.Core where

import Control.DeepSeq (NFData)
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.State (MonadState (..), gets)
import Data.Aeson (FromJSON, ToJSON)
import Data.Bifunctor (Bifunctor (..))
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HashMap (insert, lookup)
import Data.LinkedHashMap (LinkedHashMap)
import Data.LinkedHashMap qualified as LinkedHashMap
import Data.List.NonEmpty qualified as NonEmpty
Expand Down Expand Up @@ -70,7 +71,7 @@ data PropertyMetaData = PropertyMetaData

data GlobalCtx = GlobalCtx
{ globalBoundVarCtx :: !(GenericBoundCtx Name),
tensorVariableInfo :: !(HashMap TensorVariable TensorVariableInfo),
tensorVariableInfo :: !(Map TensorVariable TensorVariableInfo),
networkApplications :: !(LinkedHashMap NetworkApplication NetworkApplicationReplacement)
}

Expand Down Expand Up @@ -105,7 +106,7 @@ addUserVarToGlobalContext userVarName shape GlobalCtx {..} = do
let newGlobalCtx =
GlobalCtx
{ globalBoundVarCtx = addVectorVarToBoundVarCtx userVarName reducedUseVars globalBoundVarCtx,
tensorVariableInfo = HashMap.insert userVar variableInfo tensorVariableInfo,
tensorVariableInfo = Map.insert userVar variableInfo tensorVariableInfo,
..
}
(userVar, newGlobalCtx)
Expand Down Expand Up @@ -168,8 +169,8 @@ addNetworkApplicationToGlobalCtx app@(networkName, _) networkInfo GlobalCtx {..}
}

let newTensorVariableInfo =
HashMap.insert inputVar inputVarInfo $
HashMap.insert outputVar outputVarInfo tensorVariableInfo
Map.insert inputVar inputVarInfo $
Map.insert outputVar outputVarInfo tensorVariableInfo

let newGlobalCtx =
GlobalCtx
Expand Down Expand Up @@ -303,10 +304,10 @@ getTensorVariableShape var = do
return (tensorVariableShape info)

getRationalVariable :: (MonadState GlobalCtx m) => Lv -> m ElementVariable
getRationalVariable lv = do
getRationalVariable var = do
ctx <- get
case HashMap.lookup lv (tensorVariableInfo ctx) of
Nothing -> return lv
case Map.lookup var (tensorVariableInfo ctx) of
Nothing -> return var
Just info -> do
let rvs = elementVariables info
case rvs of
Expand All @@ -315,7 +316,7 @@ getRationalVariable lv = do

getTensorVariableInfo :: GlobalCtx -> TensorVariable -> TensorVariableInfo
getTensorVariableInfo GlobalCtx {..} var = do
case HashMap.lookup var tensorVariableInfo of
case Map.lookup var tensorVariableInfo of
Just info -> info
Nothing ->
developerError $
Expand All @@ -327,7 +328,7 @@ getReducedVariablesFor globalCtx var = elementVariables $ getTensorVariableInfo
getReducedVariableExprFor :: (MonadState GlobalCtx m, MonadLogger m) => Lv -> m (Maybe (Value Builtin))
getReducedVariableExprFor var = do
ctx <- get
return $ reducedVarExpr <$> HashMap.lookup var (tensorVariableInfo ctx)
return $ reducedVarExpr <$> Map.lookup var (tensorVariableInfo ctx)

reduceTensorExpr ::
GlobalCtx ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
-x0 +y0 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
-x0 +y2 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
+x4 -y4 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
+x4 -y6 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
+x0 -y0 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
+x0 -y2 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
-x1 +y1 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
-x1 +y3 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
+x1 -y1 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
+x1 -y3 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
-x2 +y2 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
-x2 +y4 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
+x2 -y2 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
+x2 -y4 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
-x3 +y3 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
-x3 +y5 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
+x3 -y3 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
+x3 -y5 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
// Metadata:
// - Marabou query format version: unknown
// - Vehicle version: 0.15.0+dev
-x4 +y4 <= -0.1
+x5 -y5 = 0.0
+x6 -y6 = 0.0
-x4 +y6 <= -0.1
+x5 -y0 = 0.0
+x6 -y1 = 0.0
Loading

0 comments on commit 75549fc

Please sign in to comment.