Update ristretto.sage to match goldilocks b0af87

This commit is contained in:
Henry de Valence 2018-01-16 14:30:15 -08:00
parent de377290ee
commit 59837c6ecf

390
vendor/ristretto.sage vendored
View File

@ -49,7 +49,11 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")):
"""Return 1/sqrt(x)"""
if x==0: return 0
if not is_square(x): raise exn
return 1/sqrt(x)
s = sqrt(x)
#if negative(s): s=-s
return 1/s
def inv0(x): return 1/x if x != 0 else 0
def isqrt_i(x):
"""Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
@ -117,16 +121,20 @@ class QuotientEdwardsPoint(object):
else:
return self.__class__(-self.x, -self.y)
def doubleAndEncodeSpec(self):
return (self+self).encode()
# Utility functions
@classmethod
def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False,maskHiBits=False):
"""Convert little-endian bytes to field element, sanity check length"""
if len(bytes) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(bytes))
s = dec_le(bytes)
if mustBeProper and s >= cls.F.modulus():
if mustBeProper and s >= cls.F.order():
raise InvalidEncodingException("%d out of range!" % s)
bitlen = int(ceil(log(cls.F.order())/log(2)))
if maskHiBits: s &= 2^bitlen-1
s = cls.F(s)
if mustBePositive and negative(s):
raise InvalidEncodingException("%d is negative!" % s)
@ -197,7 +205,42 @@ class RistrettoPoint(QuotientEdwardsPoint):
if negative(isr^2*num*y*t): y = -y
s = isr*y*(z-y)
return self.gfToBytes(s,mustBePositive=True)
@optimized_version_of("doubleAndEncodeSpec")
def doubleAndEncode(self):
X,Y,Z,T = self.xyzt()
a,d,mneg = self.a,self.d,self.mneg
if self.cofactor==8:
e = 2*X*Y
f = Z^2+d*T^2
g = Y^2-a*X^2
h = Z^2-d*T^2
inv1 = 1/(e*f*g*h)
z_inv = inv1*e*g # 1 / (f*h)
t_inv = inv1*f*h
if negative(e*g*z_inv):
if a==-1: sqrta = self.i
else: sqrta = -1
e,f,g,h = g,h,-e,f*sqrta
factor = self.i
else:
factor = self.magic
if negative(h*e*z_inv): g=-g
s = (h-g)*factor*g*t_inv
else:
foo = Y^2+a*X^2
bar = X*Y
den = 1/(foo*bar)
if negative(2*bar^2*den): tmp = a*X^2
else: tmp = Y^2
s = self.magic*(Z^2-tmp)*foo*den
return self.gfToBytes(s,mustBePositive=True)
@classmethod
@ -238,8 +281,9 @@ class RistrettoPoint(QuotientEdwardsPoint):
@classmethod
def elligatorSpec(cls,r0):
a,d = cls.a,cls.d
r = cls.qnr * cls.bytesToGf(r0)^2
r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
den = (d*r-a)*(a*r-d)
if den == 0: return cls()
n1 = cls.a*(r+1)*(a+d)*(d-a)/den
n2 = r*n1
if is_square(n1):
@ -253,7 +297,7 @@ class RistrettoPoint(QuotientEdwardsPoint):
@optimized_version_of("elligatorSpec")
def elligator(cls,r0):
a,d = cls.a,cls.d
r0 = cls.bytesToGf(r0)
r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
r = cls.qnr * r0^2
den = (d*r-a)*(a*r-d)
num = cls.a*(r+1)*(a+d)*(d-a)
@ -278,15 +322,11 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
if self.cofactor==8 and negative(x*y*self.isoMagic):
x,y = self.torque()
isr2 = isqrt(a*(y^2-1)) * sqrt(a*d-1)
sr = xsqrt(1-a*x^2)
assert sr in [isr2*x*y,-isr2*x*y]
altx = 1/isr2*self.isoMagic
if negative(altx): s = (1+x*y*isr2)/(a*x)
else: s = (1-x*y*isr2)/(a*x)
altx = x*y*self.isoMagic / sr
if negative(altx): s = (1+sr)/x
else: s = (1-sr)/x
return self.gfToBytes(s,mustBePositive=True)
@ -297,52 +337,141 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
s = cls.bytesToGf(s,mustBePositive=True)
if s==0: return cls()
isr = isqrt(s^4 + 2*(a-2*d)*s^2 + 1)
altx = 2*s*isr*cls.isoMagic
if negative(altx): isr = -isr
t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
altx = 2*s*cls.isoMagic/t
if negative(altx): t = -t
x = 2*s / (1+a*s^2)
y = (1-a*s^2) * isr
y = (1-a*s^2) / t
if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
return cls(x,y)
@optimized_version_of("encodeSpec")
def encode(self):
"""Encode, optimized version"""
def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
"Return s,t on jacobi curve"
a,d = self.a,self.d
x,y,z,t = self.xyzt()
if self.cofactor == 8:
# Cofactor 8 version
# Simulate IMAGINE_TWIST because that's how libdecaf does it
x = self.i*x
t = self.i*t
a = -a
d = -d
# OK, the actual libdecaf code should be here
num = (z+y)*(z-y)
den = x*y
tmp = isqrt(num*(a-d)*den^2)
if negative(tmp^2*den*num*(a-d)*t^2*self.isoMagic):
den,num = num,den
tmp *= sqrt(a-d) # witness that cofactor is 8
yisr = x*sqrt(a)
toggle = (a==1)
else:
yisr = y*(a*d-1)
toggle = False
isr = isqrt(num*(a-d)*den^2)
iden = isr * den * self.isoMagic # 1/sqrt((z+y)(z-y)) = 1/sqrt(1-Y^2) / z
inum = isr * num # sqrt(1-Y^2) * z / xysqrt(a-d) ~ 1/sqrt(1-ax^2)/z
tiisr = tmp*num
altx = tiisr*t*self.isoMagic
if negative(altx) != toggle: tiisr =- tiisr
s = tmp*den*yisr*(tiisr*z - 1)
if negative(iden*inum*self.i*t^2*(d-a)) != toggle_rotation:
iden,inum = inum,iden
fac = x*sqrt(a)
toggle=(a==-1)
else:
fac = y
toggle=False
imi = self.isoMagic * self.i
altx = inum*t*imi
neg_altx = negative(altx) != toggle_altx
if neg_altx != toggle: inum =- inum
tmp = fac*(inum*z + 1)
s = iden*tmp*imi
negm1 = (negative(s) != toggle_s) != neg_altx
if negm1: m1 = a*fac + z
else: m1 = a*fac - z
swap = toggle_s
else:
# Much simpler cofactor 4 version
num = (x+t)*(x-t)
isr = isqrt(num*(a-d)*x^2)
ratio = isr*num
if negative(ratio*self.isoMagic): ratio=-ratio
s = (a-d)*isr*x*(ratio*z - t)
ratio = isr*num
altx = ratio*self.isoMagic
neg_altx = negative(altx) != toggle_altx
if neg_altx: ratio =- ratio
tmp = ratio*z - t
s = (a-d)*isr*x*tmp
negx = (negative(s) != toggle_s) != neg_altx
if negx: m1 = -a*t + x
else: m1 = -a*t - x
swap = toggle_s
if negative(s): s = -s
return self.gfToBytes(s,mustBePositive=True)
return s,m1,a*tmp,swap
def invertElligator(self,toggle_r=False,*args,**kwargs):
"Produce preimage of self under elligator, or None"
a,d = self.a,self.d
rets = []
tr = [False,True] if self.cofactor == 8 else [False]
for toggle_rotation in tr:
for toggle_altx in [False,True]:
for toggle_s in [False,True]:
for toggle_r in [False,True]:
s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)
#print
#print toggle_rotation,toggle_altx,toggle_s
#print m1
#print m12
if self == self.__class__():
if self.cofactor == 4:
# Hacks for identity!
if toggle_altx: m12 = 1
elif toggle_s: m1 = 1
elif toggle_r: continue
## BOTH???
else:
m12 = 1
imi = self.isoMagic * self.i
if toggle_rotation:
if toggle_altx: m1 = -imi
else: m1 = +imi
else:
if toggle_altx: m1 = 0
else: m1 = a-d
rnum = (d*a*m12-m1)
rden = ((d*a-1)*m12+m1)
if swap: rnum,rden = rden,rnum
ok,sr = isqrt_i(rnum*rden*self.qnr)
if not ok: continue
sr *= rnum
#print "Works! %d %x" % (swap,sr)
if negative(sr) != toggle_r: sr = -sr
ret = self.gfToBytes(sr)
if self.elligator(ret) != self and self.elligator(ret) != -self:
print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
rets.append(bytes(ret))
return rets
@optimized_version_of("encodeSpec")
def encode(self):
"""Encode, optimized version"""
return self.gfToBytes(self.toJacobiQuartic()[0])
@classmethod
@optimized_version_of("decodeSpec")
@ -351,7 +480,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
a,d = cls.a,cls.d
s = cls.bytesToGf(s,mustBePositive=True)
if s==0: return cls()
#if s==0: return cls()
s2 = s^2
den = 1+a*s2
num = den^2 - 4*d*s2
@ -374,13 +503,63 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
x = 2*s / (1+a*s^2)
y = (1-a*s^2) / t
return cls(x,sgn*y)
@optimized_version_of("doubleAndEncodeSpec")
def doubleAndEncode(self):
X,Y,Z,T = self.xyzt()
a,d = self.a,self.d
if self.cofactor == 8:
# Cofactor 8 version
# Simulate IMAGINE_TWIST because that's how libdecaf does it
X = self.i*X
T = self.i*T
a = -a
d = -d
# TODO: This is only being called for a=-1, so could
# be wrong for a=1
e = 2*X*Y
f = Y^2+a*X^2
g = Y^2-a*X^2
h = Z^2-d*T^2
eim = e*self.isoMagic
inv = 1/(eim*g*f*h)
fh_inv = eim*g*inv*self.i
if negative(eim*g*fh_inv):
idf = g*self.isoMagic*self.i
bar = f
foo = g
test = eim*f
else:
idf = eim
bar = h
foo = -eim
test = g*h
if negative(test*fh_inv): bar =- bar
s = idf*(foo+bar)*inv*f*h
else:
xy = X*Y
h = Z^2-d*T^2
inv = 1/(xy*h)
if negative(inv*2*xy^2*self.isoMagic): tmp = Y
else: tmp = X
s = tmp^2*h*inv # = X/Y or Y/X, interestingly
return self.gfToBytes(s,mustBePositive=True)
@classmethod
def elligatorSpec(cls,r0):
def elligatorSpec(cls,r0,fromR=False):
a,d = cls.a,cls.d
r = cls.qnr * cls.bytesToGf(r0)^2
if fromR: r = r0
else: r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
den = (d*r-(d-a))*((d-a)*r-d)
if den == 0: return cls()
n1 = (r+1)*(a-2*d)/den
n2 = r*n1
if is_square(n1):
@ -394,7 +573,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
@optimized_version_of("elligatorSpec")
def elligator(cls,r0):
a,d = cls.a,cls.d
r0 = cls.bytesToGf(r0)
r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
r = cls.qnr * r0^2
den = (d*r-(d-a))*((d-a)*r-d)
num = (r+1)*(a-2*d)
@ -408,6 +587,40 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
if negative(s) == iss: s = -s
return cls.fromJacobiQuartic(s,t)
def elligatorInverseBruteForce(self):
"""Invert Elligator using SAGE's polynomial solver"""
a,d = self.a,self.d
R.<r0> = self.F[]
r = self.qnr * r0^2
den = (d*r-(d-a))*((d-a)*r-d)
n1 = (r+1)*(a-2*d)/den
n2 = r*n1
ret = set()
for s2,t in [(n1, -(r-1)*(a-2*d)^2 / den - 1),
(n2,r*(r-1)*(a-2*d)^2 / den - 1)]:
x2 = 4*s2/(1+a*s2)^2
y = (1-a*s2) / t
selfT = self
for i in xrange(self.cofactor/2):
xT,yT = selfT
polyX = xT^2-x2
polyY = yT-y
sx = set(r for r,_ in polyX.numerator().roots())
sy = set(r for r,_ in polyY.numerator().roots())
ret = ret.union(sx.intersection(sy))
selfT = selfT.torque()
ret = [self.gfToBytes(r) for r in ret]
for r in ret:
assert self.elligator(r) in [self,-self]
ret = [r for r in ret if self.elligator(r) == self]
return ret
class Ed25519Point(RistrettoPoint):
F = GF(2^255-19)
d = F(-121665/121666)
@ -455,7 +668,7 @@ class IsoEd448Point(RistrettoPoint):
@classmethod
def base(cls):
return cls( # RFC has it wrong
-345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
-363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
)
@ -464,7 +677,6 @@ class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
d = F(-39082)
a = F(-1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
isoMagic = IsoEd448Point.magic
@ -478,14 +690,13 @@ class Ed448GoldilocksPoint(Decaf_1_1_Point):
d = F(-39081)
a = F(1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
isoMagic = IsoEd448Point.magic
@classmethod
def base(cls):
return -2*cls( # FIXME: make not negative
return 2*cls(
224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
)
@ -532,19 +743,29 @@ def test(cls,n):
P = cls.base()
print "base", list(P.encode())
for i in xrange(16):
Q = P*i
print i, list(Q.encode())
Q = cls()
for i in xrange(n):
#print i, binascii.hexlify(Q.encode())
QQ = cls.decode(Q.encode())
#print binascii.hexlify(Q.encode())
QE = Q.encode()
QQ = cls.decode(QE)
if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
# Testing s -> 1/s: encodes -point on cofactor
s = cls.bytesToGf(QE)
if s != 0:
ss = cls.gfToBytes(1/s,mustBePositive=True)
try:
QN = cls.decode(ss)
if cls.cofactor == 8:
raise TestFailedException("1/s shouldnt work for cofactor 8")
if QN != -Q:
raise TestFailedException("s -> 1/s should negate point for cofactor 4")
except InvalidEncodingException as e:
# Should be raised iff cofactor==8
if cls.cofactor == 4:
raise TestFailedException("s -> 1/s should work for cofactor 4")
QT = Q
QE = Q.encode()
for h in xrange(cls.cofactor):
QT = QT.torque()
if QT.encode() != QE:
@ -559,27 +780,26 @@ def test(cls,n):
Q2 = Q0*(r+1)
if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
Q = Q1
test(Ed25519Point,100)
#test(NegEd25519Point,100)
#test(IsoEd25519Point,100)
#test(IsoEd448Point,100)
#test(TwistedEd448GoldilocksPoint,100)
#test(Ed448GoldilocksPoint,100)
def testElligator(cls,n):
print "Testing elligator on %s" % cls.__name__
for i in xrange(n):
r = randombytes(cls.encLen)
Q = cls.elligator(r)
print list(r), list(Q.encode())
testElligator(Ed25519Point,100)
#testElligator(NegEd25519Point,100)
#testElligator(IsoEd448Point,100)
#testElligator(Ed448GoldilocksPoint,100)
#testElligator(TwistedEd448GoldilocksPoint,100)
P = cls.elligator(r)
if hasattr(P,"invertElligator"):
iv = P.invertElligator()
modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True)))
iv2 = P.torque().invertElligator()
if modr not in iv: print "Failed to invert Elligator!"
if len(iv) != len(set(iv)):
print "Elligator inverses not unique!", len(set(iv)), len(iv)
if iv != iv2:
print "Elligator is untorqueable!"
#print [binascii.hexlify(j) for j in iv]
#print [binascii.hexlify(j) for j in iv2]
#break
else:
pass # TODO
def gangtest(classes,n):
print "Gang test",[cls.__name__ for cls in classes]
@ -607,5 +827,31 @@ def gangtest(classes,n):
for c,ret in zip(classes,rets):
print c,binascii.hexlify(ret)
print
gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
gangtest([Ed25519Point,IsoEd25519Point],100)
def testDoubleAndEncode(cls,n):
print "Testing doubleAndEncode on %s" % cls.__name__
for i in xrange(n):
r1 = randombytes(cls.encLen)
r2 = randombytes(cls.encLen)
u = cls.elligator(r1) + cls.elligator(r2)
u.doubleAndEncode()
testDoubleAndEncode(Ed25519Point,100)
testDoubleAndEncode(NegEd25519Point,100)
testDoubleAndEncode(IsoEd25519Point,100)
testDoubleAndEncode(IsoEd448Point,100)
testDoubleAndEncode(TwistedEd448GoldilocksPoint,100)
#test(Ed25519Point,100)
#test(NegEd25519Point,100)
#test(IsoEd25519Point,100)
#test(IsoEd448Point,100)
#test(TwistedEd448GoldilocksPoint,100)
#test(Ed448GoldilocksPoint,100)
#testElligator(Ed25519Point,100)
#testElligator(NegEd25519Point,100)
#testElligator(IsoEd25519Point,100)
#testElligator(IsoEd448Point,100)
#testElligator(Ed448GoldilocksPoint,100)
#testElligator(TwistedEd448GoldilocksPoint,100)
#gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
#gangtest([Ed25519Point,IsoEd25519Point],100)