(in-package "USER")

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; SPI method for Bayesian Inference
;;;
;;; Tom Dietterich     Sun Oct 26 11:52:52 1997
;;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(defvar *variables* nil "List of all random variables in the problem")

;;; Every variable is assigned an integer value: its position in the
;;; *variables* list.  This integer value is also its bit position in
;;; the varset bit-string representation, and therefore fixes where the
;;; variable lives in the index of every ptable.

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; varset class
;;;
;;; We need to manipulate sets of variables.  It is convenient to
;;; manipulate them as bit vectors (integers) so that union and
;;; intersection computations are efficient.
;;;

(defstructure (varset (:constructor create-varset))
  bits)     ; integer: bit k is set iff (nth k *variables*) is in the set

(defun make-varset (var-list)
  ;; Given a list of variables, construct a varset.
  ;; The order of VAR-LIST is irrelevant: each variable's bit is its
  ;; position in *variables*.  Signals an error for any variable that
  ;; is not in *variables*.
  (let ((result 0))
    (dolist (var var-list)
      (let ((bit-pos (position var *variables*)))
        (when (null bit-pos)
          (error "Unknown variable ~S given to make-varset" var))
        (setf result (logior result (ash 1 bit-pos)))))
    (create-varset :bits result)))

(defmethod vars-from-varset ((v varset))
  ;; Given a varset, return the list of its variable symbols, in the
  ;; same order as they appear in *variables*.
  (let ((bits (varset-bits v)))
    (loop for var in *variables*
          for k from 0
          when (logbitp k bits)
            collect var)))

(defmethod print-structure ((v varset) stream)
  ;; Print the varset as #<varset (var ...)>.
  (format stream "#<varset ~S>" (vars-from-varset v)))

(defmethod varset-union ((v varset) other)
  ;; Return a new varset that is the union of V and OTHER.
  (create-varset :bits (logior (varset-bits v) (varset-bits other))))

(defmethod varset-intersection ((v varset) other)
  ;; Return a new varset that is the intersection of V and OTHER.
  (create-varset :bits (logand (varset-bits v) (varset-bits other))))

(defmethod varset-complement ((v varset))
  ;; Return a new varset holding every variable of *variables* that is
  ;; NOT in V.  The result is masked to the bits of *variables*: a bare
  ;; LOGNOT would yield a negative integer (infinitely many 1 bits), for
  ;; which LOGCOUNT in VARSET-SIZE counts the 0 bits instead.
  (create-varset :bits (logandc2 (1- (ash 1 (length *variables*)))
                                 (varset-bits v))))

(defmethod varset-difference ((v varset) other)
  ;; Return a new varset: V with all variables of OTHER removed.
  (create-varset :bits (logandc2 (varset-bits v) (varset-bits other))))

(defmethod varset-size ((v varset))
  ;; Return the number of variables in V (number of 1 bits).
  (logcount (varset-bits v)))

(defun compress-under-mask (item mask)
  ;; Return an integer whose bits are the bits of ITEM selected by MASK,
  ;; packed together in their original order.  All bits of ITEM whose
  ;; corresponding bit in MASK is zero are deleted.
  ;; Example: ITEM   = 11001100
  ;;          MASK   = 10101010
  ;;          Result = 00001010
  ;; Only the low (length *variables*) bits of ITEM and MASK are examined.
  (declare (fixnum item mask))
  ;; Notes for novice Lisp programmers.
  ;;   logand is bit-wise AND of two integers
  ;;   logior is bit-wise inclusive OR of two integers
  ;;   ash is arithmetic left shift (negative argument => right shift)
  (let ((result 0)
        (one (ash 1 (- (length *variables*) 1))))
    (declare (fixnum result one))
    ;; Scan from the highest bit down, so the selected bits accumulate
    ;; into RESULT most-significant first.
    (loop while (> one 0)
          do (when (> (logand one mask) 0)        ; bit is set in mask
               (setf result (logior (ash result 1)
                                    (if (> (logand one item) 0) 1 0))))
             (setf one (ash one -1)))
    result))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; ptable class
;;;
;;; A ptable is a probability table.  It doesn't matter whether it is
;;; a conditional probability table or a joint probability table for
;;; purposes of these routines.  A joint probability table will sum to
;;; 1, whereas in a conditional probability table, all cells
;;; corresponding to fixed values for the conditioning variables will
;;; sum to 1.
;;;
;;; The probability table is implemented as a one-dimensional array.
;;; The lowest-numbered variable (i.e., the one earliest in
;;; *variables*) is varied the fastest in the index calculation.

(defstructure (ptable (:constructor create-ptable))
  vars        ; a varset giving the variables in this table
  contents)   ; a single-float vector of the probabilities

(defmethod make-ptable ((all list))
  ;; As a list, a ptable has the format
  ;;   (var-list . contents)
  ;; where var-list is a list of variables and contents is a list of the
  ;; probabilities in order.  Contents are laid out by the varset bit
  ;; positions (earliest variable in *variables* varies fastest), not by
  ;; the order in which var-list names the variables.  Each probability
  ;; is coerced to single-float so the table builds regardless of the
  ;; reader's default float format.
  (destructuring-bind (vars . cont) all
    (let ((varset (make-varset vars)))
      (create-ptable :vars varset
                     :contents (make-array
                                (expt 2 (varset-size varset))
                                :element-type 'single-float
                                :initial-contents
                                (mapcar (lambda (p) (coerce p 'single-float))
                                        cont))))))

(defmethod make-ptable ((v varset))
  ;; Return a new all-zero ptable over the variables in V.
  (create-ptable :vars v
                 :contents (make-array (expt 2 (varset-size v))
                                       :element-type 'single-float
                                       :initial-element 0.0)))

(defmethod print-structure ((pt ptable) stream)
  ;; Print the ptable as #<ptable (var ...)>.
  (format stream "#<ptable ~S>" (vars-from-varset (ptable-vars pt)))
  ; (print-ptable pt stream)
  )

(defmethod print-ptable ((pt ptable) &optional (stream t))
  ;; Display the ptable as a probability table: one row per cell, the
  ;; cell index in binary (highest-numbered variable leftmost, hence the
  ;; REVERSE in the header) followed by the probability.
  (format stream "~S Probability~%"
          (reverse (vars-from-varset (ptable-vars pt))))
  (let ((len (varset-size (ptable-vars pt)))
        (contents (ptable-contents pt)))
    (dotimes (i (length contents))
      (format stream "~V,'0B ~10T~8,4F~%" len i (aref contents i)))))

(defun deep-copy-ptable (pt)
  ;; Return a ptable sharing no structure with PT.  COPY-SEQ on a vector
  ;; returns a fresh vector of the same element type.
  (create-ptable :vars (copy-varset (ptable-vars pt))
                 :contents (copy-seq (ptable-contents pt))))

(defun ptable-variables (pt)
  ;; Return the list of variable symbols of PT, in *variables* order.
  (vars-from-varset (ptable-vars pt)))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; Operations on ptable's
;;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(defun conformal-product (pt1 pt2)
  ;; Create a new ptable that is the conformal product of PT1 and PT2.
  ;; Its variables are the union of theirs.  Each cell is the product of
  ;; the PT1 cell and the PT2 cell that agree with it on their variables.
  ;; Signals an error if PT1 and PT2 share no variables.
  (when (zerop (varset-bits (varset-intersection (ptable-vars pt1)
                                                 (ptable-vars pt2))))
    (error "Two ptables must share variables in order to take conformal product"))
  (let* ((newvars (varset-union (ptable-vars pt1) (ptable-vars pt2)))
         (result (make-ptable newvars))
         (out (ptable-contents result))
         (in1 (ptable-contents pt1))
         (in2 (ptable-contents pt2))
         ;; Masks in RESULT's index space selecting the bits that belong
         ;; to PT1 and to PT2 respectively.
         (left-mask (compress-under-mask (varset-bits (ptable-vars pt1))
                                         (varset-bits newvars)))
         (right-mask (compress-under-mask (varset-bits (ptable-vars pt2))
                                          (varset-bits newvars))))
    (dotimes (i (length out))
      (setf (aref out i)
            (* (aref in1 (compress-under-mask i left-mask))
               (aref in2 (compress-under-mask i right-mask)))))
    result))

(defmethod sum-over ((pt ptable) var-list)
  ;; Create a new ptable by summing PT over all values of the variables
  ;; in VAR-LIST (a list of variable symbols).  Variables in VAR-LIST
  ;; that do not occur in PT are ignored.  PT itself is not modified.
  ;;
  ;; We construct a mask so that we can compress each index into PT
  ;; under this mask to get the index into the answer ptable.
  (let* ((keep (varset-difference (ptable-vars pt) (make-varset var-list)))
         (keep-mask (compress-under-mask (varset-bits keep)
                                         (varset-bits (ptable-vars pt))))
         (answer (make-ptable keep))
         (in (ptable-contents pt))
         (out (ptable-contents answer)))
    (dotimes (i (length in))
      (incf (aref out (compress-under-mask i keep-mask))
            (aref in i)))
    answer))

(defmethod project ((pt ptable) var value)
  ;; Destructively modify PT to eliminate all rows where VAR is not
  ;; equal to VALUE.  The column corresponding to VAR is also removed.
  ;; The resulting table is NOT normalized.  If PT does not mention VAR
  ;; it is left unchanged.  Returns PT.
  ;; VALUE must be either 0 or 1; anything else signals an error (it
  ;; would otherwise silently produce an all-zero table).
  (unless (and (realp value) (or (= value 0) (= value 1)))
    (error "PROJECT: value ~S for ~S must be 0 or 1" value var))
  (let* ((observed-varset (make-varset (list var)))
         (observed-bits (varset-bits observed-varset))
         (old-varset (ptable-vars pt))
         (old-bits (varset-bits old-varset)))
    (when (logtest observed-bits old-bits)
      ;; This table contains VAR: allocate a new array and construct a
      ;; mask that compresses old indexes into new indexes.
      (let* ((new-varset (varset-difference old-varset observed-varset))
             (new-bits (varset-bits new-varset))
             (mask (compress-under-mask new-bits old-bits))
             (observed-mask (compress-under-mask observed-bits old-bits))
             (old-contents (ptable-contents pt))
             (new-contents (make-array (expt 2 (varset-size new-varset))
                                       :element-type 'single-float
                                       :initial-element 0.0)))
        (dotimes (i (length old-contents))
          ;; Does this index correspond to the observed value of VAR?
          (when (= value (if (logtest observed-mask i) 1 0))
            (setf (aref new-contents (compress-under-mask i mask))
                  (aref old-contents i))))
        (setf (ptable-contents pt) new-contents)
        (setf (ptable-vars pt) new-varset))))
  pt)

(defmethod normalize ((pt ptable) cond-vars-list)
  ;; Destructively normalize the values in PT so that they sum to 1 for
  ;; each distinct setting of the variables in COND-VARS-LIST.  If PT is
  ;; a joint distribution, this conditions it on those variables.
  ;; Returns PT.
  ;;
  ;; Our strategy is to create a new ptable by summing over the
  ;; variables that we are not using for conditioning.  This gives a
  ;; ptable in which each cell corresponds to a distinct setting of the
  ;; cond-vars and whose value is the normalizer.  Then we go back
  ;; through the original table, apply a mask to find each cell's
  ;; normalizer, and divide by it.
  ;;
  ;; A setting whose normalizer is zero is left unchanged (all zeros
  ;; for a probability table) instead of dividing 0 by 0.
  (let* ((cond-vars (make-varset cond-vars-list))
         (all-vars (ptable-vars pt))
         (summed-vars (varset-difference all-vars cond-vars))
         (normalizers (ptable-contents
                       (sum-over pt (vars-from-varset summed-vars))))
         (mask (compress-under-mask (varset-bits cond-vars)
                                    (varset-bits all-vars)))
         (contents (ptable-contents pt)))
    (dotimes (i (length contents))
      (let ((normalizer (aref normalizers (compress-under-mask i mask))))
        (unless (zerop normalizer)
          (setf (aref contents i) (/ (aref contents i) normalizer)))))
    pt))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; belief-net class
;;;

(defstructure (belief-net (:constructor create-belief-net))
  (nodes nil)             ; a list of ptable's
  (known-variables nil)   ; variables whose values have been tell'ed
  (known-values nil)      ; alist of known variables and values
  (variables nil))        ; all variables in the network

(defun deep-copy-belief-net (bn)
  ;; Make a copy of a bayes net that shares no structure with BN.
  (create-belief-net
   :nodes (mapcar #'deep-copy-ptable (belief-net-nodes bn))
   :known-variables (copy-tree (belief-net-known-variables bn))
   :known-values (copy-tree (belief-net-known-values bn))
   :variables (copy-tree (belief-net-variables bn))))

(defun make-belief-net (list)
  ;; A belief-net can be entered in list form as follows
  ;;   (node1 node2 ...)
  ;; where each node has the format
  ;;   (var cpt)
  ;; and the cpt has a form suitable for giving to MAKE-PTABLE.
  ;;
  ;; Side effect: sets the global *variables* to the node variables in
  ;; the order listed.  This fixes every variable's bit position, so it
  ;; must happen before any ptable is built.
  (setf *variables* (mapcar #'first list))
  (create-belief-net :nodes (loop for (nil cpt) in list
                                  collect (make-ptable cpt))
                     :variables *variables*))

;; convention: 0 = false  1 = true
(defparameter *test-net*
  '((cloudy ((cloudy)
             .5     ; f
             .5     ; t
             ))
    (sprinkler ((sprinkler cloudy)
                .5  ; f f
                .9  ; f t
                .5  ; t f
                .1  ; t t
                ))
    (rain ((rain cloudy)
           .8       ; f f
           .2       ; f t
           .2       ; t f
           .8       ; t t
           ))
    (wet-grass ((wet-grass rain sprinkler)
                1.0   ; f f f
                0.1   ; f f t
                0.1   ; f t f
                0.01  ; f t t
                0.0   ; t f f
                0.9   ; t f t
                0.9   ; t t f
                0.99  ; t t t
                ))))

;;;
;;; convention: 0 => ok, 1=> bad
(defparameter *car-net*
  '((SparkPlugs ((SparkPlugs) 0.9998 0.0002))
    (Distributor ((Distributor) 0.999 0.001))
    (FuelPump ((FuelPump) 0.999 0.001))
    (Leak2 ((Leak2) 0.9999 0.0001))
    (Starter ((Starter) 0.99 0.01))
    (BatteryAge ((BatteryAge) 0.7 0.3))
    (BatteryState ((BatteryState BatteryAge) .99 .80 .01 .20))
    (Alternator ((Alternator) 0.9995 0.0005))
    (FanBelt ((FanBelt) 0.995 0.005))
    (Leak ((Leak) 0.9999 0.0001))
    (Charge ((Charge Leak FanBelt Alternator)
             1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
             0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0))
    (BatteryPower ((BatteryPower Charge BatteryState)
                   1.0 0.0 0.0 0.0
                   0.0 1.0 1.0 1.0))
    (EngineCranks ((EngineCranks BatteryPower Starter Leak2)
                   1.0 0.0 0.0 0.0 0.2 0.0 0.0 0.0
                   0.0 1.0 1.0 1.0 0.8 1.0 1.0 1.0))
    (Starts ((GasInTank Starts EngineCranks FuelPump Distributor SparkPlugs)
             1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
             0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
             0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
             1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0))
    (Radio ((Radio BatteryPower) 0.9 0.1 0.1 0.9))
    (GasInTank ((GasInTank) 0.5 0.5))
    (GasGauge ((GasGauge GasInTank BatteryPower)
               1.0 0.1 0.0 0.0
               0.0 0.9 1.0 1.0))
    (Lights ((Lights BatteryPower) 0.9 0.0 0.1 1.0))))

;;;
;;; convention: 0 => false 1=> true
(defparameter *burglar-net*
  '((burglary ((Burglary) 0.999 0.001))
    (Earthquake ((earthquake) .998 .002))
    (alarm ((alarm earthquake burglary)
            .999   ; f f f
            .06    ; f f t
            .71    ; f t f
            .05    ; f t t
            .001   ; t f f
            .94    ; t f t
            .29    ; t t f
            .95    ; t t t
            ))
    (john-calls ((john-calls alarm)
                 .95   ; f f
                 .10   ; f t
                 .05   ; t f
                 .90   ; t t
                 ))
    (mary-calls ((mary-calls alarm)
                 .99   ; f f
                 .30   ; f t
                 .01   ; t f
                 .70   ; t t
                 ))))

(defmethod tell ((bn belief-net) var value)
  ;; Not yet implemented.  Presumably meant to record the observation
  ;; VAR = VALUE in BN (see the known-variables / known-values slots).
  ;; Currently does nothing and returns NIL.
  (declare (ignorable bn var value))
  nil)

(defmethod ask ((bn belief-net) var)
  ;; Not yet implemented.  Presumably meant to return the posterior
  ;; distribution of VAR given the tell'ed evidence.  Currently does
  ;; nothing and returns NIL.
  (declare (ignorable bn var))
  nil)