I've learned. I'll share.

January 15, 2008

List Monad in Ruby and Python

Recently, I wrote about ways to add something like Haskell's do syntax to languages like Python and Ruby. The Python trick used bidirectional generators and the Ruby trick used callcc. I have since been notified that both of them have one serious problem: the bindee (the function given to bind) can only be called once. For many monads this limitation is acceptable, but for some it means you can't use that nice syntax.

Take the List monad in Ruby, for example. Its bind has to call the bindee once for every element of the list:

## Monad (which provides bindb and maybeUnit) is defined in the earlier post
class ListMonad < Monad
  attr_accessor :array

  def initialize(array)
    @array = array
  end

  def bind(bindee)
    ListMonad.new(concat(array.map { |val| maybeUnit(bindee.call(val)).array }))
  end

  def self.unit(val)
    self.new([val])
  end

  def to_s
    joined = array.join(', ')
    "ListMonad([#{joined}])"
  end
end

## I'm hoping this is faster than using fold (inject).
## Array#concat appends in place, so no intermediate arrays are built.
def concat(arrays)
  all = []
  for array in arrays
    all.concat(array)
  end
  all
end

class Array
  def bind(bindee)
    ListMonad.new(self).bind(bindee)
  end

  def bindb(&block)
    ListMonad.new(self).bindb(&block)
  end
end
Ruby lets us extend Array to get a pretty sweet-looking use of the monad:
list1 = with_monad ListMonad do
  x =xx    [1,2,3]
  y =xx [10,20,30]
  x + y
end

It looks great, but list1 should be ListMonad([11, 21, 31, 12, 22, 32, 13, 23, 33]), and instead we get ListMonad([11]).

Traditional binding works fine:

list2 = [1, 2, 3].bindb do |x|
  [10, 20, 30].bindb do |y|
    x + y
  end
end

list2 is ListMonad([11, 21, 31, 12, 22, 32, 13, 23, 33]). So, we know our monad is fine, but our with_monad doesn't work.
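For reference, here is what the List monad's bind has to do, as a minimal Python sketch using plain lists (list_bind and list_unit are hypothetical helper names, not part of the ListMonad classes in this post). The function passed to bind runs once per element, which is exactly what a run-once continuation can't do:

```python
def list_bind(xs, f):
    # List-monad bind: apply f to every element and concatenate the results.
    return [y for x in xs for y in f(x)]


def list_unit(x):
    # List-monad unit: wrap a single value in a one-element list.
    return [x]


# Same shape as the nested bindb blocks: the inner function is called
# once for every x, and once for every y inside each of those calls.
result = list_bind([1, 2, 3], lambda x:
         list_bind([10, 20, 30], lambda y:
         list_unit(x + y)))

print(result)  # [11, 21, 31, 12, 22, 32, 13, 23, 33]
```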

We have the same problem in Python:

# Monad, do and mreturn are defined in the earlier post
class ListMonad(Monad):
    def __init__(self, vals):
        self.vals = vals

    def bind(self, bindee):
        return self.__class__(concat((self.maybeUnit(bindee(val)) for val in self.vals)))

    @classmethod
    def unit(cls, val):
        return cls([val])

    def __repr__(self):
        return "ListMonad(%s)" % self.vals

    def __iter__(self):
        return iter(self.vals)

def concat(lists):
    return [val for lst in lists
                for val in lst]

@do(ListMonad)
def with_list_monad():
    x = yield    [1, 2, 3]
    y = yield [10, 20, 30]
    mreturn(x + y)

list1 = with_list_monad()

def with_list_monad_binding():
    return \
    ListMonad([  1, 2, 3])  >> (lambda x:
    ListMonad([10, 20, 30]) >> (lambda y:
    x + y))

list2 = with_list_monad_binding()
    

Again, it looks good, but list1 is ListMonad([11, None, None, None, None]) while list2 is ListMonad([11, 21, 31, 12, 22, 32, 13, 23, 33]). The result of list1 sure is bizarre.
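The bizarre result has a simple cause: every branch of the computation shares one generator. The minimal sketch below is not the do decorator from the earlier post; it assumes plain lists for the monad and a Python 3 generator return in place of mreturn (naive_do and block are made-up names). It produces the same list: the first path (x = 1, y = 10) runs to completion, and every later send hits an already-finished generator, whose StopIteration carries None.

```python
def list_bind(xs, f):
    # List-monad bind: apply f to each element, concatenate the results.
    return [y for x in xs for y in f(x)]


def naive_do(gen_func):
    # One generator object is shared by every branch of the computation.
    gen = gen_func()

    def step(value):
        try:
            m = gen.send(value)
        except StopIteration as stop:
            # Wrap the block's return value. If the generator had already
            # finished before this send, stop.value is None.
            return [stop.value]
        return list_bind(m, step)

    return step(None)


def block():
    x = yield [1, 2, 3]
    y = yield [10, 20, 30]
    return x + y


result = naive_do(block)
print(result)  # [11, None, None, None, None]
```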

Luckily, I have found a solution for Ruby, and a bad hack for Python that could be construed as a solution if you ignore some issues.

In Ruby, the trick is to monkey with the rbind continuation so that at the end of with_monad_ext, control returns to the point where the continuation was called. In other words, we need to store a second continuation right after the first continuation is called. Confused yet? Continuations will melt your brain, and this took me a little while to get right:

def with_monad_ext(monad_class, &block)
  finishers = []

  rbind_ext = lambda do |monad|
    begin
      checked_callcc do |outer_cont|
        monad.bindb do |val|
          callcc do |inner_cont|
            finishers.push(inner_cont)
            outer_cont.call(val)
          end
        end
      end
    rescue ContinuationUnused => unused
      raise MonadEscape.new(unused.result)
    end
  end

  val = begin
    monad_class.maybeUnit(block.call(rbind_ext))
  rescue MonadEscape => escape
    escape.monad
  end

  finisher = finishers.pop()
  if finisher
    finisher.call(val)
  else
    val
  end
end

list3 = with_monad_ext ListMonad do |rbind|
  x = rbind.call    [1,2,3]
  y = rbind.call [10,20,30]
  x + y
end

Success! list3 is correctly ListMonad([11, 21, 31, 12, 22, 32, 13, 23, 33]). The only real problem is that the normal rbind won't work, so we have to make a special one and pass it to the block as an argument. That wouldn't be so bad, except that Ruby has a silly limitation: a lambda stored in a variable can't be called like a method, so we have to write rbind.call rather than just rbind, which makes it less fun to type. I named it with_monad_ext because I don't think it will be needed as often as with_monad, and with_monad is more convenient.

Now, back to Python. The problem is harder here. It's rooted in the fact that, for a given iterator, itr.send() is not reentrant: once the generator has resumed past a yield, it can't be resumed from that same yield again with a different value. But if we could copy the iterator, that would give us a possible solution. I did a little googling and found that Python isn't going to support copy.copy(itr) for generators anytime soon, because hardly anyone has asked for it. Some people have made progress on it, but their code segfaults on my computer (64-bit Linux, in case you're wondering).
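A small Python 3 sketch of what "not reentrant" means here (block is a made-up example): once a generator has resumed past a yield, it can't go back to that yield with a different value, and CPython's copy.copy refuses to copy a generator at all, raising TypeError.

```python
import copy


def block():
    x = yield "first yield"
    y = yield "second yield"
    return (x, y)


gen = block()
first = gen.send(None)   # runs up to the first yield
second = gen.send(1)     # x = 1; now paused at the second yield

# There is no way back to the first yield to try x = 2: the next send
# resumes after the *second* yield and finishes the generator.
try:
    gen.send(10)
except StopIteration as stop:
    finished = stop.value

# Nor can a fresh generator be duplicated: copy.copy raises TypeError.
try:
    copy.copy(block())
    copyable = True
except TypeError:
    copyable = False

print(first, second, finished, copyable)
```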

So, I came up with this terrible little hack which makes it possible to "copy" iterators:

class CopyableIterator:
    def __init__(self, generator, log = ()):
        self.generator = generator
        self.log       = list(log) #hmmm... if the logs were immutable, we wouldn't have to do this
        self.iterator  = None

    def getIterator(self):
        if self.iterator is None:
            # start a fresh generator and replay every value sent so far
            self.iterator = self.generator()
            for value in self.log:
                self.iterator.send(value)
        return self.iterator
            
    def send(self, value):
        iterator = self.getIterator()
        self.log.append(value)
        return iterator.send(value)

    def next(self):
        return self.send(None)

    def copy(self):
        return self.__class__(self.generator, self.log)
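Here is the same log-and-replay idea as a Python 3 sketch with a short demonstration. Assumptions of this sketch rather than the class above: __next__ replaces Python 2's next, and block is a made-up generator. A copy records where the original was, and only pays for the replay when it is first used:

```python
class CopyableIterator:
    """Log-replay "copy" of a generator: every value sent is recorded, and a
    copy starts a brand-new generator and replays the log to catch up."""

    def __init__(self, generator_func, log=()):
        self.generator_func = generator_func
        self.log = list(log)       # private list, so branches don't share it
        self.iterator = None       # created lazily on first send

    def _get_iterator(self):
        if self.iterator is None:
            self.iterator = self.generator_func()
            for value in self.log:  # replay history; the log starts with None
                self.iterator.send(value)
        return self.iterator

    def send(self, value):
        iterator = self._get_iterator()
        self.log.append(value)
        return iterator.send(value)

    def __iter__(self):
        return self

    def __next__(self):
        return self.send(None)

    def copy(self):
        return self.__class__(self.generator_func, self.log)


def block():
    x = yield "give me x"
    y = yield "give me y"
    yield x + y


original = CopyableIterator(block)
first = original.send(None)        # "give me x"
second = original.send(1)          # "give me y"  (x = 1)
branch = original.copy()           # snapshot taken while waiting for y
from_original = original.send(10)  # 11
from_branch = branch.send(20)      # 21: branch replays None, 1, then sends 20
print(from_original, from_branch)
```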

The CopyableIterator class lets us define @do_ext:

import types

@decorator_with_args
def do_ext(func, func_args, func_kargs, Monad):
    @handle_monadic_throws(Monad)
    def run_maybe_iterator():
        itr = func(*func_args, **func_kargs)

        if isinstance(itr, types.GeneratorType):
            @handle_monadic_throws(Monad)
            def send(itr, val):
                try:
                    # here's the real magic: bind may call this lambda many
                    # times, and each call resumes its own copy of the iterator
                    monad = Monad.maybeNew(itr.send(val))
                    return monad.bind(lambda val : send(itr.copy(), val))
                except StopIteration:
                    return Monad.unit(None)
                
            return send(CopyableIterator(lambda : func(*func_args, **func_kargs)), None)
        else:
            #not really a generator
            if itr is None:
                return Monad.unit(None)
            else:
                return itr

    return run_maybe_iterator()

@do_ext(ListMonad)
def with_list_monad_ext():
    x = yield    [1, 2, 3]
    y = yield [10, 20, 30]
    mreturn(x + y)

list3 = with_list_monad_ext()

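And list3 is ListMonad([11, 21, 31, 12, 22, 32, 13, 23, 33]). Success again! There is a downside, though. "Copying" the iterator means starting a fresh generator and replaying every value sent so far, so the code leading up to each yield is run again for every combination. This is especially bad if that code has side effects: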

@do_ext(ListMonad)
def with_list_monad_ext():
    print "foo"
    x = yield    [1, 2, 3]
    y = yield [10, 20, 30]
    mreturn(x + y)

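Because now you'll print "foo" 13 times instead of once: once for the initial run, once for each of the 3 replays that bind x, and once for each of the 9 replays that bind x and y.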
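To count the cost, here is a self-contained Python 3 sketch of the same replay strategy. replay_do is a hypothetical stand-in for do_ext plus CopyableIterator, with plain lists as the monad, a generator return instead of mreturn, and a counter instead of print. Each branch starts a fresh generator and replays the values sent so far:

```python
def list_bind(xs, f):
    # List-monad bind: apply f to each element, concatenate the results.
    return [y for x in xs for y in f(x)]


def replay_do(gen_func):
    """Drive a generator through the list monad, replaying it from scratch
    for every branch (the strategy do_ext + CopyableIterator uses)."""
    def run(history, value):
        gen = gen_func()
        for old in history:          # replay to reach the current point
            gen.send(old)
        try:
            m = gen.send(value)
        except StopIteration as stop:
            return [stop.value]      # block finished: unit of its result
        new_history = history + [value]
        return list_bind(m, lambda v: run(new_history, v))
    return run([], None)


prelude_runs = 0


def block():
    global prelude_runs
    prelude_runs += 1                # stands in for: print "foo"
    x = yield [1, 2, 3]
    y = yield [10, 20, 30]
    return x + y


result = replay_do(block)
print(result)        # [11, 21, 31, 12, 22, 32, 13, 23, 33]
print(prelude_runs)  # 13 = 1 initial run + 3 replays for x + 9 for (x, y)
```

The 9 results come out right, but the prelude runs 13 times, and longer lists or more yields make the count grow multiplicatively.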

So there you have it: you can do the List monad with do syntax in Python and Ruby, but you'll have to decide whether the trade-offs are worth it.

1 comment:

  1. Does it go up to 3 or more invocations? I'm having difficulty getting all your code together to try it myself.

    For the record, here's my ParseTree version, which now supports the arrow "operator" and is nestable:

    http://repo.or.cz/w/ruby-do-notation.git
