from vowpalwabbit import pyvw
# Training data: each sentence is a list of (word, parent) pairs, where
# parent is the 0-based index of the word's head, or -1 for the root word.
my_dataset = [
    [
        ("the", 1),       # 0
        ("monster", 2),   # 1
        ("ate", -1),      # 2  (root)
        ("a", 5),         # 3
        ("big", 5),       # 4
        ("sandwich", 2),  # 5
    ],
    [
        ("the", 1),       # 0
        ("sandwich", 2),  # 1
        ("is", -1),       # 2  (root)
        ("tasty", 2),     # 3
    ],
    [
        ("a", 1),         # 0
        ("sandwich", 2),  # 1
        ("ate", -1),      # 2  (root)
        ("itself", 2),    # 3
    ],
]
class CovingtonDepParser(pyvw.SearchTask):
    """Covington-style dependency parser as a (non-LDF) VW search task.

    For each word n, candidate heads m are scanned (m == -1 is the virtual
    root) and a binary question is asked per pair: "is m the parent of n?".
    Action 2 means yes, action 1 means no; the first "yes" fixes the arc.
    """

    def __init__(self, vw, sch, num_actions):
        pyvw.SearchTask.__init__(self, vw, sch, num_actions)
        # Hamming loss over predicted arcs; auto-condition features on the
        # tagged history of earlier predictions.
        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)

    def _run(self, sentence):
        """Parse one sentence (list of (word, parent) pairs) into a list of
        head indices, -1 meaning root."""
        N = len(sentence)
        # initialize our output so everything is a root
        output = [-1 for i in range(N)]
        for n in range(N):
            wordN, parN = sentence[n]
            for m in range(-1, N):
                if m == n:
                    continue
                # FIX: was `m > 0`, which wrongly treated m == 0 (the first
                # word) as the root; only m == -1 is the virtual root, as in
                # the LDF variant below.
                wordM = sentence[m][0] if m >= 0 else "*root*"
                # ask the question: is m the parent of n?
                isParent = 2 if m == parN else 1
                # direction of the candidate arc: head left or right of n
                dir = "l" if m < n else "r"
                ex = self.vw.example(
                    {
                        "a": [wordN, dir + "_" + wordN],
                        # FIX: second "b" feature paired dir with wordN
                        # (copy-paste slip); it should describe the head
                        # word wordM, matching the LDF variant.
                        "b": [wordM, dir + "_" + wordM],
                        "p": [wordN + "_" + wordM, dir + "_" + wordN + "_" + wordM],
                        # binned signed-distance features for m - n
                        "d": [
                            str(m - n <= d) + "<=" + str(d)
                            for d in [-8, -4, -2, -1, 1, 2, 4, 8]
                        ]
                        + [
                            str(m - n >= d) + ">=" + str(d)
                            for d in [-8, -4, -2, -1, 1, 2, 4, 8]
                        ],
                    }
                )
                pred = self.sch.predict(
                    examples=ex,
                    my_tag=(m + 1) * N + n + 1,
                    oracle=isParent,
                    condition=[
                        (max(0, (m) * N + n + 1), "p"),
                        (max(0, (m + 1) * N + n), "q"),
                    ],
                )
                # Must pass the example in as a list because search is a
                # MultiEx reduction. FIX: use self.vw, not the module-level
                # workspace, so the task works with any owning workspace.
                self.vw.finish_example([ex])
                if pred == 2:
                    output[n] = m
                    break
        return output
class CovingtonDepParserLDF(pyvw.SearchTask):
    """Covington-style dependency parser as a label-dependent-features (LDF)
    VW search task: for each word n, one cost-sensitive example is built per
    candidate head m, and a single prediction picks among them.
    """

    def __init__(self, vw, sch, num_actions):
        pyvw.SearchTask.__init__(self, vw, sch, num_actions)
        sch.set_options(
            sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES
        )

    def makeExample(self, sentence, n, m):
        """Build one cost-sensitive example for the arc candidate m -> n
        (m == -1 is the virtual root)."""
        wordN = sentence[n][0]
        wordM = sentence[m][0] if m >= 0 else "*ROOT*"
        # direction of the candidate arc: head left or right of n
        dir = "l" if m < n else "r"
        ex = self.vw.example(
            {
                "a": [wordN, dir + "_" + wordN],
                "b": [wordM, dir + "_" + wordM],
                "p": [wordN + "_" + wordM, dir + "_" + wordN + "_" + wordM],
                # binned signed-distance features for m - n
                "d": [
                    str(m - n <= d) + "<=" + str(d)
                    for d in [-8, -4, -2, -1, 1, 2, 4, 8]
                ]
                + [
                    str(m - n >= d) + ">=" + str(d)
                    for d in [-8, -4, -2, -1, 1, 2, 4, 8]
                ],
            },
            labelType=self.vw.lCostSensitive,
        )
        # Label identifies this action. The value doesn't affect search
        # oracle/prediction (which use 0-based example indices), but it
        # must be a valid positive integer for csoaa_ldf.
        ex.set_label_string(str(m + 2) + ":0")
        return ex

    def _run(self, sentence):
        """Parse one sentence (list of (word, parent) pairs) into a list of
        head indices, -1 meaning root."""
        N = len(sentence)
        # initialize our output so everything is a root
        output = [-1 for i in range(N)]
        for n in range(N):
            # make LDF examples
            examples = []
            for m in range(-1, N):
                if n != m:
                    examples.append(self.makeExample(sentence=sentence, n=n, m=m))
            # truth
            parN = sentence[n][1]
            # In LDF mode, oracle is the 0-based index of the correct
            # example in the examples list.
            # Examples are ordered: m=-1, 0, ..., n-1, n+1, ..., N-1
            # So the index of m=parN is:
            #   parN + 1 when parN < n (m=-1 is at index 0, m=0 at 1, ...)
            #   parN     when parN > n (m=n is skipped)
            oracle = parN + 1 if parN < n else parN
            # make a prediction
            pred = self.sch.predict(
                examples=examples,
                my_tag=n + 1,
                oracle=oracle,
                condition=[(n, "p"), (n - 1, "q")],
            )
            # FIX: use self.vw, not the module-level workspace, so the task
            # works with any owning workspace.
            self.vw.finish_example(examples)
            # Reverse mapping: convert 0-based example index back to m
            #   index 0..n    => m = index - 1 (i.e., m=-1, 0, ..., n-1)
            #   index n+1..N-1 => m = index    (i.e., m=n+1, ..., N-1)
            output[n] = pred - 1 if pred <= n else pred
        return output
# TODO: if they make sure search=0 <==> ldf <==> csoaa_ldf

# Demo the non-LDF version: binary "is m the parent of n?" questions.
print("training non-LDF")
vw = pyvw.Workspace("--search 2 --search_task hook", quiet=True)
task = vw.init_search_task(CovingtonDepParser)
for _ in range(2):  # two passes over the training data
    task.learn(my_dataset)
print("testing non-LDF")
print(task.predict([(w, -1) for w in "the monster ate a sandwich".split()]))
print("should have printed [ 1 2 -1 4 2 ]")

# Demo the LDF version: one cost-sensitive example per candidate head.
print("training LDF")
vw = pyvw.Workspace("--search 0 --csoaa_ldf m --search_task hook", quiet=True)
task = vw.init_search_task(CovingtonDepParserLDF)
for _ in range(100):  # LDF needs more passes over the training data to converge
    task.learn(my_dataset)
print("testing LDF")
print(task.predict([(w, -1) for w in "the monster ate a sandwich".split()]))
print("should have printed [ 1 2 -1 4 2 ]")